pydantic-ai-slim 0.0.19__py3-none-any.whl → 0.0.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -61,6 +61,7 @@ KnownModelName = Literal[
     'mistral:codestral-latest',
     'mistral:mistral-moderation-latest',
     'ollama:codellama',
+    'ollama:deepseek-r1',
     'ollama:gemma',
     'ollama:gemma2',
     'ollama:llama3',
@@ -81,6 +82,22 @@ KnownModelName = Literal[
     'anthropic:claude-3-5-haiku-latest',
     'anthropic:claude-3-5-sonnet-latest',
     'anthropic:claude-3-opus-latest',
+    'claude-3-5-haiku-latest',
+    'claude-3-5-sonnet-latest',
+    'claude-3-opus-latest',
+    'cohere:c4ai-aya-expanse-32b',
+    'cohere:c4ai-aya-expanse-8b',
+    'cohere:command',
+    'cohere:command-light',
+    'cohere:command-light-nightly',
+    'cohere:command-nightly',
+    'cohere:command-r',
+    'cohere:command-r-03-2024',
+    'cohere:command-r-08-2024',
+    'cohere:command-r-plus',
+    'cohere:command-r-plus-04-2024',
+    'cohere:command-r-plus-08-2024',
+    'cohere:command-r7b-12-2024',
     'test',
 ]
 """Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent].
@@ -145,6 +162,7 @@ class AgentModel(ABC):
 class StreamedResponse(ABC):
     """Streamed response from an LLM when calling a tool."""
 
+    _model_name: str
     _usage: Usage = field(default_factory=Usage, init=False)
     _parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False)
     _event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False)
@@ -168,7 +186,13 @@ class StreamedResponse(ABC):
 
     def get(self) -> ModelResponse:
         """Build a [`ModelResponse`][pydantic_ai.messages.ModelResponse] from the data received from the stream so far."""
-        return ModelResponse(parts=self._parts_manager.get_parts(), timestamp=self.timestamp())
+        return ModelResponse(
+            parts=self._parts_manager.get_parts(), model_name=self._model_name, timestamp=self.timestamp()
+        )
+
+    def model_name(self) -> str:
+        """Get the model name of the response."""
+        return self._model_name
 
     def usage(self) -> Usage:
         """Get the usage of the response so far. This will not be the final usage until the stream is exhausted."""
@@ -228,6 +252,10 @@ def infer_model(model: Model | KnownModelName) -> Model:
         from .test import TestModel
 
         return TestModel()
+    elif model.startswith('cohere:'):
+        from .cohere import CohereModel
+
+        return CohereModel(model[7:])
     elif model.startswith('openai:'):
         from .openai import OpenAIModel
 
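`infer_model` strips the `cohere:` prefix and passes the remainder to the new `CohereModel`. A small sketch, assuming `COHERE_API_KEY` is available so the underlying client can be constructed:

```python
from pydantic_ai.models import infer_model

model = infer_model('cohere:command-r')  # returns a CohereModel
print(model.name())  # 'cohere:command-r'
```
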
 
@@ -1,14 +1,16 @@
 from __future__ import annotations as _annotations
 
-from collections.abc import AsyncIterator
+from collections.abc import AsyncIterable, AsyncIterator
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from json import JSONDecodeError, loads as json_loads
 from typing import Any, Literal, Union, cast, overload
 
 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never
 
-from .. import usage
+from .. import UnexpectedModelBehavior, _utils, usage
 from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
     ArgsDict,
@@ -16,6 +18,7 @@ from ..messages import (
     ModelRequest,
     ModelResponse,
     ModelResponsePart,
+    ModelResponseStreamEvent,
     RetryPromptPart,
     SystemPromptPart,
     TextPart,
@@ -38,11 +41,16 @@ try:
     from anthropic.types import (
         Message as AnthropicMessage,
         MessageParam,
+        RawContentBlockDeltaEvent,
+        RawContentBlockStartEvent,
+        RawContentBlockStopEvent,
         RawMessageDeltaEvent,
         RawMessageStartEvent,
+        RawMessageStopEvent,
         RawMessageStreamEvent,
         TextBlock,
         TextBlockParam,
+        TextDelta,
         ToolChoiceParam,
         ToolParam,
         ToolResultBlockParam,
@@ -152,7 +160,7 @@ class AnthropicAgentModel(AgentModel):
     """Implementation of `AgentModel` for Anthropic models."""
 
     client: AsyncAnthropic
-    model_name: str
+    model_name: AnthropicModelName
     allow_text_result: bool
     tools: list[ToolParam]
 
@@ -186,16 +194,22 @@ class AnthropicAgentModel(AgentModel):
         self, messages: list[ModelMessage], stream: bool, model_settings: ModelSettings | None
     ) -> AnthropicMessage | AsyncStream[RawMessageStreamEvent]:
         # standalone function to make it easier to override
+        model_settings = model_settings or {}
+
+        tool_choice: ToolChoiceParam | None
+
         if not self.tools:
-            tool_choice: ToolChoiceParam | None = None
-        elif not self.allow_text_result:
-            tool_choice = {'type': 'any'}
+            tool_choice = None
         else:
-            tool_choice = {'type': 'auto'}
+            if not self.allow_text_result:
+                tool_choice = {'type': 'any'}
+            else:
+                tool_choice = {'type': 'auto'}
 
-        system_prompt, anthropic_messages = self._map_message(messages)
+            if (allow_parallel_tool_calls := model_settings.get('parallel_tool_calls')) is not None:
+                tool_choice['disable_parallel_tool_use'] = not allow_parallel_tool_calls
 
-        model_settings = model_settings or {}
+        system_prompt, anthropic_messages = self._map_message(messages)
 
         return await self.client.messages.create(
             max_tokens=model_settings.get('max_tokens', 1024),
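
The request builder also starts honouring the `parallel_tool_calls` model setting, mapping it onto Anthropic's `disable_parallel_tool_use` flag. A hedged sketch, assuming `parallel_tool_calls` is accepted in `ModelSettings` at this version:

```python
from pydantic_ai import Agent

agent = Agent('anthropic:claude-3-5-sonnet-latest')

# Ask Anthropic not to issue tool calls in parallel for this run.
result = agent.run_sync(
    'What is the weather in Paris and in London?',
    model_settings={'parallel_tool_calls': False},
)
```
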
@@ -210,8 +224,7 @@ class AnthropicAgentModel(AgentModel):
             timeout=model_settings.get('timeout', NOT_GIVEN),
         )
 
-    @staticmethod
-    def _process_response(response: AnthropicMessage) -> ModelResponse:
+    def _process_response(self, response: AnthropicMessage) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
         items: list[ModelResponsePart] = []
         for item in response.content:
@@ -227,26 +240,17 @@ class AnthropicAgentModel(AgentModel):
                     )
                 )
 
-        return ModelResponse(items)
+        return ModelResponse(items, model_name=self.model_name)
 
-    @staticmethod
-    async def _process_streamed_response(response: AsyncStream[RawMessageStreamEvent]) -> StreamedResponse:
-        """TODO: Process a streamed response, and prepare a streaming response to return."""
-        # We don't yet support streamed responses from Anthropic, so we raise an error here for now.
-        # Streamed responses will be supported in a future release.
-
-        raise RuntimeError('Streamed responses are not yet supported for Anthropic models.')
-
-        # Should be returning some sort of AnthropicStreamTextResponse or AnthropicStreamedResponse
-        # depending on the type of chunk we get, but we need to establish how we handle (and when we get) the following:
-        # RawMessageStartEvent
-        # RawMessageDeltaEvent
-        # RawMessageStopEvent
-        # RawContentBlockStartEvent
-        # RawContentBlockDeltaEvent
-        # RawContentBlockDeltaEvent
-        #
-        # We might refactor streaming internally before we implement this...
+    async def _process_streamed_response(self, response: AsyncStream[RawMessageStreamEvent]) -> StreamedResponse:
+        peekable_response = _utils.PeekableAsyncStream(response)
+        first_chunk = await peekable_response.peek()
+        if isinstance(first_chunk, _utils.Unset):
+            raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
+
+        # Since Anthropic doesn't provide a timestamp in the message, we'll use the current time
+        timestamp = datetime.now(tz=timezone.utc)
+        return AnthropicStreamedResponse(_model_name=self.model_name, _response=peekable_response, _timestamp=timestamp)
 
     @staticmethod
     def _map_message(messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]:
@@ -342,3 +346,63 @@ def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) -> usage.Usage
         response_tokens=response_usage.output_tokens,
         total_tokens=(request_tokens or 0) + response_usage.output_tokens,
     )
+
+
+@dataclass
+class AnthropicStreamedResponse(StreamedResponse):
+    """Implementation of `StreamedResponse` for Anthropic models."""
+
+    _response: AsyncIterable[RawMessageStreamEvent]
+    _timestamp: datetime
+
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+        current_block: TextBlock | ToolUseBlock | None = None
+        current_json: str = ''
+
+        async for event in self._response:
+            self._usage += _map_usage(event)
+
+            if isinstance(event, RawContentBlockStartEvent):
+                current_block = event.content_block
+                if isinstance(current_block, TextBlock) and current_block.text:
+                    yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=current_block.text)
+                elif isinstance(current_block, ToolUseBlock):
+                    maybe_event = self._parts_manager.handle_tool_call_delta(
+                        vendor_part_id=current_block.id,
+                        tool_name=current_block.name,
+                        args=cast(dict[str, Any], current_block.input),
+                        tool_call_id=current_block.id,
+                    )
+                    if maybe_event is not None:
+                        yield maybe_event
+
+            elif isinstance(event, RawContentBlockDeltaEvent):
+                if isinstance(event.delta, TextDelta):
+                    yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=event.delta.text)
+                elif (
+                    current_block and event.delta.type == 'input_json_delta' and isinstance(current_block, ToolUseBlock)
+                ):
+                    # Try to parse the JSON immediately, otherwise cache the value for later. This handles
+                    # cases where the JSON is not currently valid but will be valid once we stream more tokens.
+                    try:
+                        parsed_args = json_loads(current_json + event.delta.partial_json)
+                        current_json = ''
+                    except JSONDecodeError:
+                        current_json += event.delta.partial_json
+                        continue
+
+                    # For tool calls, we need to handle partial JSON updates
+                    maybe_event = self._parts_manager.handle_tool_call_delta(
+                        vendor_part_id=current_block.id,
+                        tool_name='',
+                        args=parsed_args,
+                        tool_call_id=current_block.id,
+                    )
+                    if maybe_event is not None:
+                        yield maybe_event
+
+            elif isinstance(event, (RawContentBlockStopEvent, RawMessageStopEvent)):
+                current_block = None
+
+    def timestamp(self) -> datetime:
+        return self._timestamp
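
With `_process_streamed_response` implemented and `AnthropicStreamedResponse` added (this file appears to be `pydantic_ai/models/anthropic.py`), streaming no longer raises `RuntimeError` for Anthropic models. A hedged sketch using the high-level API, assuming `Agent.run_stream` and `stream_text` behave as in the other 0.0.x releases:

```python
import asyncio

from pydantic_ai import Agent

agent = Agent('anthropic:claude-3-5-haiku-latest')


async def main() -> None:
    # Streamed Anthropic responses were previously unsupported; they now work.
    async with agent.run_stream('Tell me a short joke.') as result:
        async for chunk in result.stream_text(delta=True):
            print(chunk, end='', flush=True)


asyncio.run(main())
```
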
@@ -0,0 +1,278 @@
+from __future__ import annotations as _annotations
+
+from collections.abc import Iterable
+from dataclasses import dataclass, field
+from itertools import chain
+from typing import Literal, TypeAlias, Union
+
+from cohere import TextAssistantMessageContentItem
+from typing_extensions import assert_never
+
+from .. import result
+from .._utils import guard_tool_call_id as _guard_tool_call_id
+from ..messages import (
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    ModelResponsePart,
+    RetryPromptPart,
+    SystemPromptPart,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
+    UserPromptPart,
+)
+from ..settings import ModelSettings
+from ..tools import ToolDefinition
+from . import (
+    AgentModel,
+    Model,
+    check_allow_model_requests,
+)
+
+try:
+    from cohere import (
+        AssistantChatMessageV2,
+        AsyncClientV2,
+        ChatMessageV2,
+        ChatResponse,
+        SystemChatMessageV2,
+        ToolCallV2,
+        ToolCallV2Function,
+        ToolChatMessageV2,
+        ToolV2,
+        ToolV2Function,
+        UserChatMessageV2,
+    )
+    from cohere.v2.client import OMIT
+except ImportError as _import_error:
+    raise ImportError(
+        'Please install `cohere` to use the Cohere model, '
+        "you can use the `cohere` optional group — `pip install 'pydantic-ai-slim[cohere]'`"
+    ) from _import_error
+
+CohereModelName: TypeAlias = Union[
+    str,
+    Literal[
+        'c4ai-aya-expanse-32b',
+        'c4ai-aya-expanse-8b',
+        'command',
+        'command-light',
+        'command-light-nightly',
+        'command-nightly',
+        'command-r',
+        'command-r-03-2024',
+        'command-r-08-2024',
+        'command-r-plus',
+        'command-r-plus-04-2024',
+        'command-r-plus-08-2024',
+        'command-r7b-12-2024',
+    ],
+]
+
+
+@dataclass(init=False)
+class CohereModel(Model):
+    """A model that uses the Cohere API.
+
+    Internally, this uses the [Cohere Python client](
+    https://github.com/cohere-ai/cohere-python) to interact with the API.
+
+    Apart from `__init__`, all methods are private or match those of the base class.
+    """
+
+    model_name: CohereModelName
+    client: AsyncClientV2 = field(repr=False)
+
+    def __init__(
+        self,
+        model_name: CohereModelName,
+        *,
+        api_key: str | None = None,
+        cohere_client: AsyncClientV2 | None = None,
+    ):
+        """Initialize an Cohere model.
+
+        Args:
+            model_name: The name of the Cohere model to use. List of model names
+                available [here](https://docs.cohere.com/docs/models#command).
+            api_key: The API key to use for authentication, if not provided, the
+                `COHERE_API_KEY` environment variable will be used if available.
+            cohere_client: An existing Cohere async client to use. If provided,
+                `api_key` must be `None`.
+        """
+        self.model_name: CohereModelName = model_name
+        if cohere_client is not None:
+            assert api_key is None, 'Cannot provide both `cohere_client` and `api_key`'
+            self.client = cohere_client
+        else:
+            self.client = AsyncClientV2(api_key=api_key)  # type: ignore
+
+    async def agent_model(
+        self,
+        *,
+        function_tools: list[ToolDefinition],
+        allow_text_result: bool,
+        result_tools: list[ToolDefinition],
+    ) -> AgentModel:
+        check_allow_model_requests()
+        tools = [self._map_tool_definition(r) for r in function_tools]
+        if result_tools:
+            tools += [self._map_tool_definition(r) for r in result_tools]
+        return CohereAgentModel(
+            self.client,
+            self.model_name,
+            allow_text_result,
+            tools,
+        )
+
+    def name(self) -> str:
+        return f'cohere:{self.model_name}'
+
+    @staticmethod
+    def _map_tool_definition(f: ToolDefinition) -> ToolV2:
+        return ToolV2(
+            type='function',
+            function=ToolV2Function(
+                name=f.name,
+                description=f.description,
+                parameters=f.parameters_json_schema,
+            ),
+        )
+
+
+@dataclass
+class CohereAgentModel(AgentModel):
+    """Implementation of `AgentModel` for Cohere models."""
+
+    client: AsyncClientV2
+    model_name: CohereModelName
+    allow_text_result: bool
+    tools: list[ToolV2]
+
+    async def request(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> tuple[ModelResponse, result.Usage]:
+        response = await self._chat(messages, model_settings)
+        return self._process_response(response), _map_usage(response)
+
+    async def _chat(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+    ) -> ChatResponse:
+        cohere_messages = list(chain(*(self._map_message(m) for m in messages)))
+        model_settings = model_settings or {}
+        return await self.client.chat(
+            model=self.model_name,
+            messages=cohere_messages,
+            tools=self.tools or OMIT,
+            max_tokens=model_settings.get('max_tokens', OMIT),
+            temperature=model_settings.get('temperature', OMIT),
+            p=model_settings.get('top_p', OMIT),
+        )
+
+    def _process_response(self, response: ChatResponse) -> ModelResponse:
+        """Process a non-streamed response, and prepare a message to return."""
+        parts: list[ModelResponsePart] = []
+        if response.message.content is not None and len(response.message.content) > 0:
+            # While Cohere's API returns a list, it only does that for future proofing
+            # and currently only one item is being returned.
+            choice = response.message.content[0]
+            parts.append(TextPart(choice.text))
+        for c in response.message.tool_calls or []:
+            if c.function and c.function.name and c.function.arguments:
+                parts.append(
+                    ToolCallPart.from_raw_args(
+                        tool_name=c.function.name,
+                        args=c.function.arguments,
+                        tool_call_id=c.id,
+                    )
+                )
+        return ModelResponse(parts=parts, model_name=self.model_name)
+
+    @classmethod
+    def _map_message(cls, message: ModelMessage) -> Iterable[ChatMessageV2]:
+        """Just maps a `pydantic_ai.Message` to a `cohere.ChatMessageV2`."""
+        if isinstance(message, ModelRequest):
+            yield from cls._map_user_message(message)
+        elif isinstance(message, ModelResponse):
+            texts: list[str] = []
+            tool_calls: list[ToolCallV2] = []
+            for item in message.parts:
+                if isinstance(item, TextPart):
+                    texts.append(item.content)
+                elif isinstance(item, ToolCallPart):
+                    tool_calls.append(_map_tool_call(item))
+                else:
+                    assert_never(item)
+            message_param = AssistantChatMessageV2(role='assistant')
+            if texts:
+                message_param.content = [TextAssistantMessageContentItem(text='\n\n'.join(texts))]
+            if tool_calls:
+                message_param.tool_calls = tool_calls
+            yield message_param
+        else:
+            assert_never(message)
+
+    @classmethod
+    def _map_user_message(cls, message: ModelRequest) -> Iterable[ChatMessageV2]:
+        for part in message.parts:
+            if isinstance(part, SystemPromptPart):
+                yield SystemChatMessageV2(role='system', content=part.content)
+            elif isinstance(part, UserPromptPart):
+                yield UserChatMessageV2(role='user', content=part.content)
+            elif isinstance(part, ToolReturnPart):
+                yield ToolChatMessageV2(
+                    role='tool',
+                    tool_call_id=_guard_tool_call_id(t=part, model_source='Cohere'),
+                    content=part.model_response_str(),
+                )
+            elif isinstance(part, RetryPromptPart):
+                if part.tool_name is None:
+                    yield UserChatMessageV2(role='user', content=part.model_response())
+                else:
+                    yield ToolChatMessageV2(
+                        role='tool',
+                        tool_call_id=_guard_tool_call_id(t=part, model_source='Cohere'),
+                        content=part.model_response(),
+                    )
+            else:
+                assert_never(part)
+
+
+def _map_tool_call(t: ToolCallPart) -> ToolCallV2:
+    return ToolCallV2(
+        id=_guard_tool_call_id(t=t, model_source='Cohere'),
+        type='function',
+        function=ToolCallV2Function(
+            name=t.tool_name,
+            arguments=t.args_as_json_str(),
+        ),
+    )
+
+
+def _map_usage(response: ChatResponse) -> result.Usage:
+    usage = response.usage
+    if usage is None:
+        return result.Usage()
+    else:
+        details: dict[str, int] = {}
+        if usage.billed_units is not None:
+            if usage.billed_units.input_tokens:
+                details['input_tokens'] = int(usage.billed_units.input_tokens)
+            if usage.billed_units.output_tokens:
+                details['output_tokens'] = int(usage.billed_units.output_tokens)
+            if usage.billed_units.search_units:
+                details['search_units'] = int(usage.billed_units.search_units)
+            if usage.billed_units.classifications:
+                details['classifications'] = int(usage.billed_units.classifications)
+
+        request_tokens = int(usage.tokens.input_tokens) if usage.tokens and usage.tokens.input_tokens else None
+        response_tokens = int(usage.tokens.output_tokens) if usage.tokens and usage.tokens.output_tokens else None
+        return result.Usage(
+            request_tokens=request_tokens,
+            response_tokens=response_tokens,
+            total_tokens=(request_tokens or 0) + (response_tokens or 0),
+            details=details,
+        )
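
This new module (apparently `pydantic_ai/models/cohere.py`) can also be used directly instead of a `cohere:`-prefixed name, for example to pass an explicit API key or a pre-configured `AsyncClientV2`. A minimal sketch:

```python
from pydantic_ai import Agent
from pydantic_ai.models.cohere import CohereModel

# Equivalent to Agent('cohere:command-r-plus'), but with an explicit key.
model = CohereModel('command-r-plus', api_key='my-cohere-api-key')
agent = Agent(model)

result = agent.run_sync('Give me one word that rhymes with "stream".')
print(result.data)
```
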
@@ -71,16 +71,15 @@ class FunctionModel(Model):
         result_tools: list[ToolDefinition],
     ) -> AgentModel:
         return FunctionAgentModel(
-            self.function, self.stream_function, AgentInfo(function_tools, allow_text_result, result_tools, None)
+            self.function,
+            self.stream_function,
+            AgentInfo(function_tools, allow_text_result, result_tools, None),
         )
 
     def name(self) -> str:
-        labels: list[str] = []
-        if self.function is not None:
-            labels.append(self.function.__name__)
-        if self.stream_function is not None:
-            labels.append(f'stream-{self.stream_function.__name__}')
-        return f'function:{",".join(labels)}'
+        function_name = self.function.__name__ if self.function is not None else ''
+        stream_function_name = self.stream_function.__name__ if self.stream_function is not None else ''
+        return f'function:{function_name}:{stream_function_name}'
 
 
 @dataclass(frozen=True)
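
`FunctionModel.name()` now always emits both slots in a fixed `function:<function>:<stream_function>` format rather than a comma-joined list. A sketch of what that looks like, assuming a canned test function as is typical with `FunctionModel`:

```python
from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
from pydantic_ai.models.function import AgentInfo, FunctionModel


def my_model_function(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
    # A canned response, as you might return in tests.
    return ModelResponse(parts=[TextPart('hello')])


model = FunctionModel(my_model_function)
print(model.name())  # 'function:my_model_function:' (the stream slot is empty here)
```
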
@@ -147,12 +146,15 @@ class FunctionAgentModel(AgentModel):
         agent_info = replace(self.agent_info, model_settings=model_settings)
 
         assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests'
+        model_name = f'function:{self.function.__name__}'
+
         if inspect.iscoroutinefunction(self.function):
             response = await self.function(messages, agent_info)
         else:
             response_ = await _utils.run_in_executor(self.function, messages, agent_info)
             assert isinstance(response_, ModelResponse), response_
             response = response_
+        response.model_name = model_name
         # TODO is `messages` right here? Should it just be new messages?
         return response, _estimate_usage(chain(messages, [response]))
 
@@ -163,13 +165,15 @@ class FunctionAgentModel(AgentModel):
         assert (
             self.stream_function is not None
         ), 'FunctionModel must receive a `stream_function` to support streamed requests'
+        model_name = f'function:{self.stream_function.__name__}'
+
         response_stream = PeekableAsyncStream(self.stream_function(messages, self.agent_info))
 
         first = await response_stream.peek()
         if isinstance(first, _utils.Unset):
             raise ValueError('Stream function must return at least one item')
 
-        yield FunctionStreamedResponse(response_stream)
+        yield FunctionStreamedResponse(_model_name=model_name, _iter=response_stream)
 
 
 @dataclass
@@ -99,6 +99,7 @@ class GeminiModel(Model):
         allow_text_result: bool,
         result_tools: list[ToolDefinition],
     ) -> GeminiAgentModel:
+        check_allow_model_requests()
         return GeminiAgentModel(
             http_client=self.http_client,
             model_name=self.model_name,
@@ -151,7 +152,6 @@ class GeminiAgentModel(AgentModel):
         allow_text_result: bool,
         result_tools: list[ToolDefinition],
     ):
-        check_allow_model_requests()
         tools = [_function_from_abstract_tool(t) for t in function_tools]
         if result_tools:
             tools += [_function_from_abstract_tool(t) for t in result_tools]
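
For Gemini, the `check_allow_model_requests()` guard moves from `GeminiAgentModel.__init__` to `GeminiModel.agent_model`, matching the other providers (the new Cohere model calls it in the same place). A hedged sketch of what that guard is for, assuming the module-level `ALLOW_MODEL_REQUESTS` flag and the un-prefixed Gemini model names of this release:

```python
from pydantic_ai import Agent, models

models.ALLOW_MODEL_REQUESTS = False  # typically set in a test suite

agent = Agent('gemini-1.5-flash')

# The guard now fires when the agent model is built for a run rather than when
# GeminiAgentModel is constructed, so the error surfaces at run time.
agent.run_sync('hello')  # raises because real model requests are disallowed
```
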
@@ -229,15 +229,13 @@ class GeminiAgentModel(AgentModel):
                 raise exceptions.UnexpectedModelBehavior(f'Unexpected response from gemini {r.status_code}', r.text)
             yield r
 
-    @staticmethod
-    def _process_response(response: _GeminiResponse) -> ModelResponse:
+    def _process_response(self, response: _GeminiResponse) -> ModelResponse:
         if len(response['candidates']) != 1:
             raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response')
         parts = response['candidates'][0]['content']['parts']
-        return _process_response_from_parts(parts)
+        return _process_response_from_parts(parts, model_name=self.model_name)
 
-    @staticmethod
-    async def _process_streamed_response(http_response: HTTPResponse) -> StreamedResponse:
+    async def _process_streamed_response(self, http_response: HTTPResponse) -> StreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         aiter_bytes = http_response.aiter_bytes()
         start_response: _GeminiResponse | None = None
@@ -258,7 +256,7 @@ class GeminiAgentModel(AgentModel):
         if start_response is None:
             raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
 
-        return GeminiStreamedResponse(_content=content, _stream=aiter_bytes)
+        return GeminiStreamedResponse(_model_name=self.model_name, _content=content, _stream=aiter_bytes)
 
     @classmethod
     def _message_to_gemini_content(
@@ -432,7 +430,9 @@ def _function_call_part_from_call(tool: ToolCallPart) -> _GeminiFunctionCallPart
     return _GeminiFunctionCallPart(function_call=_GeminiFunctionCall(name=tool.tool_name, args=tool.args_as_dict()))
 
 
-def _process_response_from_parts(parts: Sequence[_GeminiPartUnion], timestamp: datetime | None = None) -> ModelResponse:
+def _process_response_from_parts(
+    parts: Sequence[_GeminiPartUnion], model_name: GeminiModelName, timestamp: datetime | None = None
+) -> ModelResponse:
     items: list[ModelResponsePart] = []
     for part in parts:
         if 'text' in part:
@@ -448,7 +448,7 @@ def _process_response_from_parts(parts: Sequence[_GeminiPartUnion], timestamp: d
             raise exceptions.UnexpectedModelBehavior(
                 f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
             )
-    return ModelResponse(items, timestamp=timestamp or _utils.now_utc())
+    return ModelResponse(parts=items, model_name=model_name, timestamp=timestamp or _utils.now_utc())
 
 
 class _GeminiFunctionCall(TypedDict):