pydantic-ai-slim 0.0.12__py3-none-any.whl → 0.0.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of pydantic-ai-slim might be problematic.

@@ -4,23 +4,28 @@ from collections.abc import AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
+from itertools import chain
 from typing import Literal, Union, overload

 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never

 from .. import UnexpectedModelBehavior, _utils, result
+from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
-    ArgsJson,
-    Message,
-    ModelAnyResponse,
-    ModelStructuredResponse,
-    ModelTextResponse,
-    RetryPrompt,
-    ToolCall,
-    ToolReturn,
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    ModelResponsePart,
+    RetryPromptPart,
+    SystemPromptPart,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
+    UserPromptPart,
 )
-from ..result import Cost
+from ..result import Usage
+from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
     AgentModel,
@@ -40,7 +45,7 @@ try:
 except ImportError as _import_error:
     raise ImportError(
         'Please install `openai` to use the OpenAI model, '
-        "you can use the `openai` optional group — `pip install 'pydantic-ai[openai]'`"
+        "you can use the `openai` optional group — `pip install 'pydantic-ai-slim[openai]'`"
     ) from _import_error

 OpenAIModelName = Union[ChatModel, str]
@@ -66,6 +71,7 @@ class OpenAIModel(Model):
         self,
         model_name: OpenAIModelName,
         *,
+        base_url: str | None = None,
         api_key: str | None = None,
         openai_client: AsyncOpenAI | None = None,
         http_client: AsyncHTTPClient | None = None,
@@ -76,22 +82,25 @@ class OpenAIModel(Model):
             model_name: The name of the OpenAI model to use. List of model names available
                 [here](https://github.com/openai/openai-python/blob/v1.54.3/src/openai/types/chat_model.py#L7)
                 (Unfortunately, despite being ask to do so, OpenAI do not provide `.inv` files for their API).
+            base_url: The base url for the OpenAI requests. If not provided, the `OPENAI_BASE_URL` environment variable
+                will be used if available. Otherwise, defaults to OpenAI's base url.
             api_key: The API key to use for authentication, if not provided, the `OPENAI_API_KEY` environment variable
                 will be used if available.
             openai_client: An existing
                 [`AsyncOpenAI`](https://github.com/openai/openai-python?tab=readme-ov-file#async-usage)
-                client to use, if provided, `api_key` and `http_client` must be `None`.
+                client to use. If provided, `base_url`, `api_key`, and `http_client` must be `None`.
             http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
         """
         self.model_name: OpenAIModelName = model_name
         if openai_client is not None:
             assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
+            assert base_url is None, 'Cannot provide both `openai_client` and `base_url`'
             assert api_key is None, 'Cannot provide both `openai_client` and `api_key`'
             self.client = openai_client
         elif http_client is not None:
-            self.client = AsyncOpenAI(api_key=api_key, http_client=http_client)
+            self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client)
         else:
-            self.client = AsyncOpenAI(api_key=api_key, http_client=cached_async_http_client())
+            self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=cached_async_http_client())

     async def agent_model(
         self,
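
The new `base_url` parameter above makes the OpenAI client reusable against any OpenAI-compatible endpoint, with `OPENAI_BASE_URL` as the environment fallback. A minimal sketch of how it might be used; the endpoint, model name, and key below are placeholders, not values from this diff:

    from pydantic_ai.models.openai import OpenAIModel

    # Hypothetical OpenAI-compatible server; omitting base_url falls back to
    # the OPENAI_BASE_URL environment variable, then to OpenAI's default.
    model = OpenAIModel(
        'local-model',  # placeholder model name
        base_url='http://localhost:8080/v1',  # placeholder endpoint
        api_key='unused-locally',  # placeholder key
    )
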
@@ -135,28 +144,34 @@ class OpenAIAgentModel(AgentModel):
     allow_text_result: bool
     tools: list[chat.ChatCompletionToolParam]

-    async def request(self, messages: list[Message]) -> tuple[ModelAnyResponse, result.Cost]:
-        response = await self._completions_create(messages, False)
-        return self._process_response(response), _map_cost(response)
+    async def request(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> tuple[ModelResponse, result.Usage]:
+        response = await self._completions_create(messages, False, model_settings)
+        return self._process_response(response), _map_usage(response)

     @asynccontextmanager
-    async def request_stream(self, messages: list[Message]) -> AsyncIterator[EitherStreamedResponse]:
-        response = await self._completions_create(messages, True)
+    async def request_stream(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> AsyncIterator[EitherStreamedResponse]:
+        response = await self._completions_create(messages, True, model_settings)
         async with response:
             yield await self._process_streamed_response(response)

     @overload
     async def _completions_create(
-        self, messages: list[Message], stream: Literal[True]
+        self, messages: list[ModelMessage], stream: Literal[True], model_settings: ModelSettings | None
     ) -> AsyncStream[ChatCompletionChunk]:
         pass

     @overload
-    async def _completions_create(self, messages: list[Message], stream: Literal[False]) -> chat.ChatCompletion:
+    async def _completions_create(
+        self, messages: list[ModelMessage], stream: Literal[False], model_settings: ModelSettings | None
+    ) -> chat.ChatCompletion:
         pass

     async def _completions_create(
-        self, messages: list[Message], stream: bool
+        self, messages: list[ModelMessage], stream: bool, model_settings: ModelSettings | None
     ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
         # standalone function to make it easier to override
         if not self.tools:
@@ -166,7 +181,10 @@ class OpenAIAgentModel(AgentModel):
         else:
             tool_choice = 'auto'

-        openai_messages = [self._map_message(m) for m in messages]
+        openai_messages = list(chain(*(self._map_message(m) for m in messages)))
+
+        model_settings = model_settings or {}
+
         return await self.client.chat.completions.create(
             model=self.model_name,
             messages=openai_messages,
@@ -176,27 +194,30 @@ class OpenAIAgentModel(AgentModel):
             tool_choice=tool_choice or NOT_GIVEN,
             stream=stream,
             stream_options={'include_usage': True} if stream else NOT_GIVEN,
+            max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
+            temperature=model_settings.get('temperature', NOT_GIVEN),
+            top_p=model_settings.get('top_p', NOT_GIVEN),
+            timeout=model_settings.get('timeout', NOT_GIVEN),
         )

     @staticmethod
-    def _process_response(response: chat.ChatCompletion) -> ModelAnyResponse:
+    def _process_response(response: chat.ChatCompletion) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
         timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
         choice = response.choices[0]
+        items: list[ModelResponsePart] = []
+        if choice.message.content is not None:
+            items.append(TextPart(choice.message.content))
         if choice.message.tool_calls is not None:
-            return ModelStructuredResponse(
-                [ToolCall.from_json(c.function.name, c.function.arguments, c.id) for c in choice.message.tool_calls],
-                timestamp=timestamp,
-            )
-        else:
-            assert choice.message.content is not None, choice
-            return ModelTextResponse(choice.message.content, timestamp=timestamp)
+            for c in choice.message.tool_calls:
+                items.append(ToolCallPart.from_raw_args(c.function.name, c.function.arguments, c.id))
+        return ModelResponse(items, timestamp=timestamp)

     @staticmethod
     async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> EitherStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         timestamp: datetime | None = None
-        start_cost = Cost()
+        start_usage = Usage()
         # the first chunk may contain enough information so we iterate until we get either `tool_calls` or `content`
         while True:
             try:
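
Because `ModelSettings` is consumed with `.get(...)` above, it behaves like a plain mapping: only the keys a caller sets reach the request, everything else stays `NOT_GIVEN`. A hedged sketch of passing settings through an agent run, assuming `Agent.run_sync` accepts a `model_settings` argument in this release:

    from pydantic_ai import Agent

    agent = Agent('openai:gpt-4o')  # model name for illustration only

    # Only temperature and max_tokens are forwarded to
    # chat.completions.create; top_p and timeout stay NOT_GIVEN.
    result = agent.run_sync(
        'Summarise this in one sentence.',
        model_settings={'temperature': 0.0, 'max_tokens': 128},
    )
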
@@ -205,63 +226,73 @@ class OpenAIAgentModel(AgentModel):
                 raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e

             timestamp = timestamp or datetime.fromtimestamp(chunk.created, tz=timezone.utc)
-            start_cost += _map_cost(chunk)
+            start_usage += _map_usage(chunk)

             if chunk.choices:
                 delta = chunk.choices[0].delta

                 if delta.content is not None:
-                    return OpenAIStreamTextResponse(delta.content, response, timestamp, start_cost)
+                    return OpenAIStreamTextResponse(delta.content, response, timestamp, start_usage)
                 elif delta.tool_calls is not None:
                     return OpenAIStreamStructuredResponse(
                         response,
                         {c.index: c for c in delta.tool_calls},
                         timestamp,
-                        start_cost,
+                        start_usage,
                     )
                 # else continue until we get either delta.content or delta.tool_calls

-    @staticmethod
-    def _map_message(message: Message) -> chat.ChatCompletionMessageParam:
+    @classmethod
+    def _map_message(cls, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
         """Just maps a `pydantic_ai.Message` to a `openai.types.ChatCompletionMessageParam`."""
-        if message.role == 'system':
-            # SystemPrompt ->
-            return chat.ChatCompletionSystemMessageParam(role='system', content=message.content)
-        elif message.role == 'user':
-            # UserPrompt ->
-            return chat.ChatCompletionUserMessageParam(role='user', content=message.content)
-        elif message.role == 'tool-return':
-            # ToolReturn ->
-            return chat.ChatCompletionToolMessageParam(
-                role='tool',
-                tool_call_id=_guard_tool_id(message),
-                content=message.model_response_str(),
-            )
-        elif message.role == 'retry-prompt':
-            # RetryPrompt ->
-            if message.tool_name is None:
-                return chat.ChatCompletionUserMessageParam(role='user', content=message.model_response())
-            else:
-                return chat.ChatCompletionToolMessageParam(
-                    role='tool',
-                    tool_call_id=_guard_tool_id(message),
-                    content=message.model_response(),
-                )
-        elif message.role == 'model-text-response':
-            # ModelTextResponse ->
-            return chat.ChatCompletionAssistantMessageParam(role='assistant', content=message.content)
-        elif message.role == 'model-structured-response':
-            assert (
-                message.role == 'model-structured-response'
-            ), f'Expected role to be "llm-tool-calls", got {message.role}'
-            # ModelStructuredResponse ->
-            return chat.ChatCompletionAssistantMessageParam(
-                role='assistant',
-                tool_calls=[_map_tool_call(t) for t in message.calls],
-            )
+        if isinstance(message, ModelRequest):
+            yield from cls._map_user_message(message)
+        elif isinstance(message, ModelResponse):
+            texts: list[str] = []
+            tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
+            for item in message.parts:
+                if isinstance(item, TextPart):
+                    texts.append(item.content)
+                elif isinstance(item, ToolCallPart):
+                    tool_calls.append(_map_tool_call(item))
+                else:
+                    assert_never(item)
+            message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')
+            if texts:
+                # Note: model responses from this model should only have one text item, so the following
+                # shouldn't merge multiple texts into one unless you switch models between runs:
+                message_param['content'] = '\n\n'.join(texts)
+            if tool_calls:
+                message_param['tool_calls'] = tool_calls
+            yield message_param
         else:
             assert_never(message)

+    @classmethod
+    def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
+        for part in message.parts:
+            if isinstance(part, SystemPromptPart):
+                yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
+            elif isinstance(part, UserPromptPart):
+                yield chat.ChatCompletionUserMessageParam(role='user', content=part.content)
+            elif isinstance(part, ToolReturnPart):
+                yield chat.ChatCompletionToolMessageParam(
+                    role='tool',
+                    tool_call_id=_guard_tool_call_id(t=part, model_source='OpenAI'),
+                    content=part.model_response_str(),
+                )
+            elif isinstance(part, RetryPromptPart):
+                if part.tool_name is None:
+                    yield chat.ChatCompletionUserMessageParam(role='user', content=part.model_response())
+                else:
+                    yield chat.ChatCompletionToolMessageParam(
+                        role='tool',
+                        tool_call_id=_guard_tool_call_id(t=part, model_source='OpenAI'),
+                        content=part.model_response(),
+                    )
+            else:
+                assert_never(part)
+

 @dataclass
 class OpenAIStreamTextResponse(StreamTextResponse):
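
The rework above replaces the role-tagged `Message` union with two kinds of message, each carrying typed parts: `ModelRequest` (system prompt, user prompt, tool return, retry) and `ModelResponse` (text and tool calls). `_map_message` is now a generator because one request can expand into several OpenAI messages, while one response collapses into a single assistant message. A rough sketch of the new shapes; the constructor arguments are inferred from this diff, so treat them as assumptions:

    from pydantic_ai.messages import (
        ModelRequest,
        ModelResponse,
        SystemPromptPart,
        TextPart,
        ToolCallPart,
        UserPromptPart,
    )

    # One ModelRequest can hold several parts and maps to several OpenAI messages:
    request = ModelRequest(
        parts=[
            SystemPromptPart(content='You are terse.'),
            UserPromptPart(content='What is 2 + 2?'),
        ]
    )

    # One ModelResponse maps to a single assistant message; text parts are
    # joined with '\n\n' and tool calls are attached alongside:
    response = ModelResponse(
        parts=[
            TextPart(content='Let me check.'),
            ToolCallPart.from_raw_args('calculator', '{"expression": "2 + 2"}', 'call_1'),
        ]
    )
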
@@ -270,7 +301,7 @@ class OpenAIStreamTextResponse(StreamTextResponse):
     _first: str | None
     _response: AsyncStream[ChatCompletionChunk]
     _timestamp: datetime
-    _cost: result.Cost
+    _usage: result.Usage
     _buffer: list[str] = field(default_factory=list, init=False)

     async def __anext__(self) -> None:
@@ -280,7 +311,7 @@ class OpenAIStreamTextResponse(StreamTextResponse):
             return None

         chunk = await self._response.__anext__()
-        self._cost += _map_cost(chunk)
+        self._usage += _map_usage(chunk)
         try:
             choice = chunk.choices[0]
         except IndexError:
@@ -296,8 +327,8 @@ class OpenAIStreamTextResponse(StreamTextResponse):
         yield from self._buffer
         self._buffer.clear()

-    def cost(self) -> Cost:
-        return self._cost
+    def usage(self) -> Usage:
+        return self._usage

     def timestamp(self) -> datetime:
         return self._timestamp
@@ -310,11 +341,11 @@ class OpenAIStreamStructuredResponse(StreamStructuredResponse):
     _response: AsyncStream[ChatCompletionChunk]
     _delta_tool_calls: dict[int, ChoiceDeltaToolCall]
     _timestamp: datetime
-    _cost: result.Cost
+    _usage: result.Usage

     async def __anext__(self) -> None:
         chunk = await self._response.__anext__()
-        self._cost += _map_cost(chunk)
+        self._usage += _map_usage(chunk)
         try:
             choice = chunk.choices[0]
         except IndexError:
@@ -335,48 +366,41 @@ class OpenAIStreamStructuredResponse(StreamStructuredResponse):
         else:
             self._delta_tool_calls[new.index] = new

-    def get(self, *, final: bool = False) -> ModelStructuredResponse:
-        calls: list[ToolCall] = []
+    def get(self, *, final: bool = False) -> ModelResponse:
+        items: list[ModelResponsePart] = []
         for c in self._delta_tool_calls.values():
             if f := c.function:
                 if f.name is not None and f.arguments is not None:
-                    calls.append(ToolCall.from_json(f.name, f.arguments, c.id))
+                    items.append(ToolCallPart.from_raw_args(f.name, f.arguments, c.id))

-        return ModelStructuredResponse(calls, timestamp=self._timestamp)
+        return ModelResponse(items, timestamp=self._timestamp)

-    def cost(self) -> Cost:
-        return self._cost
+    def usage(self) -> Usage:
+        return self._usage

     def timestamp(self) -> datetime:
         return self._timestamp


-def _guard_tool_id(t: ToolCall | ToolReturn | RetryPrompt) -> str:
-    """Type guard that checks a `tool_id` is not None both for static typing and runtime."""
-    assert t.tool_id is not None, f'OpenAI requires `tool_id` to be set: {t}'
-    return t.tool_id
-
-
-def _map_tool_call(t: ToolCall) -> chat.ChatCompletionMessageToolCallParam:
-    assert isinstance(t.args, ArgsJson), f'Expected ArgsJson, got {t.args}'
+def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
     return chat.ChatCompletionMessageToolCallParam(
-        id=_guard_tool_id(t),
+        id=_guard_tool_call_id(t=t, model_source='OpenAI'),
         type='function',
-        function={'name': t.tool_name, 'arguments': t.args.args_json},
+        function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
     )


-def _map_cost(response: chat.ChatCompletion | ChatCompletionChunk) -> result.Cost:
+def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk) -> result.Usage:
     usage = response.usage
     if usage is None:
-        return result.Cost()
+        return result.Usage()
     else:
         details: dict[str, int] = {}
         if usage.completion_tokens_details is not None:
             details.update(usage.completion_tokens_details.model_dump(exclude_none=True))
         if usage.prompt_tokens_details is not None:
             details.update(usage.prompt_tokens_details.model_dump(exclude_none=True))
-        return result.Cost(
+        return result.Usage(
             request_tokens=usage.prompt_tokens,
             response_tokens=usage.completion_tokens,
             total_tokens=usage.total_tokens,
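
`Usage` supports in-place addition (see `start_usage += _map_usage(chunk)` above), which is how per-chunk usage accumulates into a run total during streaming. A small illustration, assuming `+` sums the token fields:

    from pydantic_ai.result import Usage

    total = Usage()
    for chunk_usage in (
        Usage(request_tokens=10, response_tokens=2, total_tokens=12),
        Usage(request_tokens=0, response_tokens=5, total_tokens=5),
    ):
        total += chunk_usage  # field-wise sum, mirroring the streaming loop

    assert total.total_tokens == 17  # assumed semantics
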
@@ -9,18 +9,20 @@ from datetime import date, datetime, timedelta
 from typing import Any, Literal

 import pydantic_core
+from typing_extensions import assert_never

 from .. import _utils
 from ..messages import (
-    Message,
-    ModelAnyResponse,
-    ModelStructuredResponse,
-    ModelTextResponse,
-    RetryPrompt,
-    ToolCall,
-    ToolReturn,
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    RetryPromptPart,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
 )
-from ..result import Cost
+from ..result import Usage
+from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
     AgentModel,
@@ -29,6 +31,7 @@ from . import (
     StreamStructuredResponse,
     StreamTextResponse,
 )
+from .function import _estimate_string_usage, _estimate_usage  # pyright: ignore[reportPrivateUsage]


 @dataclass
@@ -127,74 +130,85 @@ class TestAgentModel(AgentModel):
     result_tools: list[ToolDefinition]
     seed: int

-    async def request(self, messages: list[Message]) -> tuple[ModelAnyResponse, Cost]:
-        return self._request(messages), Cost()
+    async def request(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> tuple[ModelResponse, Usage]:
+        model_response = self._request(messages, model_settings)
+        usage = _estimate_usage([*messages, model_response])
+        return model_response, usage

     @asynccontextmanager
-    async def request_stream(self, messages: list[Message]) -> AsyncIterator[EitherStreamedResponse]:
-        msg = self._request(messages)
-        cost = Cost()
-        if isinstance(msg, ModelTextResponse):
-            yield TestStreamTextResponse(msg.content, cost)
+    async def request_stream(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> AsyncIterator[EitherStreamedResponse]:
+        msg = self._request(messages, model_settings)
+        usage = _estimate_usage(messages)
+
+        # TODO: Rework this once we make StreamTextResponse more general
+        texts: list[str] = []
+        tool_calls: list[ToolCallPart] = []
+        for item in msg.parts:
+            if isinstance(item, TextPart):
+                texts.append(item.content)
+            elif isinstance(item, ToolCallPart):
+                tool_calls.append(item)
+            else:
+                assert_never(item)
+
+        if texts:
+            yield TestStreamTextResponse('\n\n'.join(texts), usage)
         else:
-            yield TestStreamStructuredResponse(msg, cost)
+            yield TestStreamStructuredResponse(msg, usage)

     def gen_tool_args(self, tool_def: ToolDefinition) -> Any:
         return _JsonSchemaTestData(tool_def.parameters_json_schema, self.seed).generate()

-    def _request(self, messages: list[Message]) -> ModelAnyResponse:
+    def _request(self, messages: list[ModelMessage], model_settings: ModelSettings | None) -> ModelResponse:
         # if there are tools, the first thing we want to do is call all of them
-        if self.tool_calls and not any(m.role == 'model-structured-response' for m in messages):
-            calls = [ToolCall.from_dict(name, self.gen_tool_args(args)) for name, args in self.tool_calls]
-            return ModelStructuredResponse(calls=calls)
-
-        # get messages since the last model response
-        new_messages = _get_new_messages(messages)
-
-        # check if there are any retry prompts, if so retry them
-        new_retry_names = {m.tool_name for m in new_messages if isinstance(m, RetryPrompt)}
-        if new_retry_names:
-            calls = [
-                ToolCall.from_dict(name, self.gen_tool_args(args))
-                for name, args in self.tool_calls
-                if name in new_retry_names
-            ]
-            return ModelStructuredResponse(calls=calls)
+        if self.tool_calls and not any(isinstance(m, ModelResponse) for m in messages):
+            return ModelResponse(
+                parts=[ToolCallPart.from_raw_args(name, self.gen_tool_args(args)) for name, args in self.tool_calls]
+            )
+
+        if messages:
+            last_message = messages[-1]
+            assert isinstance(last_message, ModelRequest), 'Expected last message to be a `ModelRequest`.'
+
+            # check if there are any retry prompts, if so retry them
+            new_retry_names = {p.tool_name for p in last_message.parts if isinstance(p, RetryPromptPart)}
+            if new_retry_names:
+                return ModelResponse(
+                    parts=[
+                        ToolCallPart.from_raw_args(name, self.gen_tool_args(args))
+                        for name, args in self.tool_calls
+                        if name in new_retry_names
+                    ]
+                )

         if response_text := self.result.left:
             if response_text.value is None:
                 # build up details of tool responses
                 output: dict[str, Any] = {}
                 for message in messages:
-                    if isinstance(message, ToolReturn):
-                        output[message.tool_name] = message.content
+                    if isinstance(message, ModelRequest):
+                        for part in message.parts:
+                            if isinstance(part, ToolReturnPart):
+                                output[part.tool_name] = part.content
                 if output:
-                    return ModelTextResponse(content=pydantic_core.to_json(output).decode())
+                    return ModelResponse.from_text(pydantic_core.to_json(output).decode())
                 else:
-                    return ModelTextResponse(content='success (no tool calls)')
+                    return ModelResponse.from_text('success (no tool calls)')
             else:
-                return ModelTextResponse(content=response_text.value)
+                return ModelResponse.from_text(response_text.value)
         else:
             assert self.result_tools, 'No result tools provided'
             custom_result_args = self.result.right
             result_tool = self.result_tools[self.seed % len(self.result_tools)]
             if custom_result_args is not None:
-                return ModelStructuredResponse(calls=[ToolCall.from_dict(result_tool.name, custom_result_args)])
+                return ModelResponse(parts=[ToolCallPart.from_raw_args(result_tool.name, custom_result_args)])
             else:
                 response_args = self.gen_tool_args(result_tool)
-                return ModelStructuredResponse(calls=[ToolCall.from_dict(result_tool.name, response_args)])
-
-
-def _get_new_messages(messages: list[Message]) -> list[Message]:
-    last_model_index = None
-    for i, m in enumerate(messages):
-        if m.role in ('model-structured-response', 'model-text-response'):
-            last_model_index = i
-
-    if last_model_index is not None:
-        return messages[last_model_index + 1 :]
-    else:
-        return []
+                return ModelResponse(parts=[ToolCallPart.from_raw_args(result_tool.name, response_args)])


 @dataclass
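
With the new structure, `TestAgentModel._request` no longer scans role strings: it calls every configured tool as long as the history contains no `ModelResponse`, then retries only the tools named in `RetryPromptPart`s of the last request. A hedged sketch of driving it through `TestModel`, assuming an agent can be built from a `TestModel` instance and that `tool_plain` is available in this release:

    from pydantic_ai import Agent
    from pydantic_ai.models.test import TestModel

    agent = Agent(TestModel())  # TestModel invents tool args from each tool's JSON schema

    @agent.tool_plain
    def add(a: int, b: int) -> int:
        return a + b

    # First model turn calls `add` with schema-generated args; the next turn
    # returns a JSON summary of the tool results as text.
    result = agent.run_sync('any prompt')
    print(result.data)
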
@@ -202,7 +216,7 @@ class TestStreamTextResponse(StreamTextResponse):
     """A text response that streams test data."""

     _text: str
-    _cost: Cost
+    _usage: Usage
     _iter: Iterator[str] = field(init=False)
     _timestamp: datetime = field(default_factory=_utils.now_utc)
     _buffer: list[str] = field(default_factory=list, init=False)
@@ -217,14 +231,17 @@ class TestStreamTextResponse(StreamTextResponse):
         self._iter = iter(words)

     async def __anext__(self) -> None:
-        self._buffer.append(_utils.sync_anext(self._iter))
+        next_str = _utils.sync_anext(self._iter)
+        response_tokens = _estimate_string_usage(next_str)
+        self._usage += Usage(response_tokens=response_tokens, total_tokens=response_tokens)
+        self._buffer.append(next_str)

     def get(self, *, final: bool = False) -> Iterable[str]:
         yield from self._buffer
         self._buffer.clear()

-    def cost(self) -> Cost:
-        return self._cost
+    def usage(self) -> Usage:
+        return self._usage

     def timestamp(self) -> datetime:
         return self._timestamp
@@ -234,19 +251,19 @@ class TestStreamTextResponse(StreamTextResponse):
 class TestStreamStructuredResponse(StreamStructuredResponse):
     """A structured response that streams test data."""

-    _structured_response: ModelStructuredResponse
-    _cost: Cost
+    _structured_response: ModelResponse
+    _usage: Usage
     _iter: Iterator[None] = field(default_factory=lambda: iter([None]))
     _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)

     async def __anext__(self) -> None:
         return _utils.sync_anext(self._iter)

-    def get(self, *, final: bool = False) -> ModelStructuredResponse:
+    def get(self, *, final: bool = False) -> ModelResponse:
         return self._structured_response

-    def cost(self) -> Cost:
-        return self._cost
+    def usage(self) -> Usage:
+        return self._usage

     def timestamp(self) -> datetime:
         return self._timestamp
@@ -21,7 +21,7 @@ try:
 except ImportError as _import_error:
     raise ImportError(
         'Please install `google-auth` to use the VertexAI model, '
-        "you can use the `vertexai` optional group — `pip install 'pydantic-ai[vertexai]'`"
+        "you can use the `vertexai` optional group — `pip install 'pydantic-ai-slim[vertexai]'`"
     ) from _import_error

 VERTEX_AI_URL_TEMPLATE = (
@@ -114,7 +114,7 @@ class VertexAIModel(Model):
         allow_text_result: bool,
         result_tools: list[ToolDefinition],
     ) -> GeminiAgentModel:
-        url, auth = await self._ainit()
+        url, auth = await self.ainit()
         return GeminiAgentModel(
             http_client=self.http_client,
             model_name=self.model_name,
@@ -125,7 +125,11 @@ class VertexAIModel(Model):
             result_tools=result_tools,
         )

-    async def _ainit(self) -> tuple[str, BearerTokenAuth]:
+    async def ainit(self) -> tuple[str, BearerTokenAuth]:
+        """Initialize the model, setting the URL and auth.
+
+        This will raise an error if authentication fails.
+        """
         if self.url is not None and self.auth is not None:
             return self.url, self.auth
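
Renaming `_ainit` to `ainit` gives callers a supported way to resolve the request URL and credentials eagerly, e.g. to fail fast on bad service-account configuration before the first request. A sketch, assuming default application credentials; the model name is a placeholder:

    import asyncio

    from pydantic_ai.models.vertexai import VertexAIModel

    async def main() -> None:
        model = VertexAIModel('gemini-1.5-flash')  # placeholder model name
        # Eagerly resolve URL and bearer token; raises if authentication fails.
        url, auth = await model.ainit()
        print(url)

    asyncio.run(main())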