pydantic-ai-slim 0.0.18 → 0.0.20 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

pydantic_ai/models/openai.py

@@ -1,6 +1,6 @@
 from __future__ import annotations as _annotations

-from collections.abc import AsyncIterator, Iterable
+from collections.abc import AsyncIterable, AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
@@ -10,13 +10,14 @@ from typing import Literal, Union, overload
 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never

-from .. import UnexpectedModelBehavior, _utils, result
+from .. import UnexpectedModelBehavior, _utils, usage
 from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
     ModelMessage,
     ModelRequest,
     ModelResponse,
     ModelResponsePart,
+    ModelResponseStreamEvent,
     RetryPromptPart,
     SystemPromptPart,
     TextPart,
@@ -24,15 +25,12 @@ from ..messages import (
     ToolReturnPart,
     UserPromptPart,
 )
-from ..result import Usage
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
     AgentModel,
-    EitherStreamedResponse,
     Model,
-    StreamStructuredResponse,
-    StreamTextResponse,
+    StreamedResponse,
     cached_async_http_client,
     check_allow_model_requests,
 )
@@ -41,7 +39,6 @@ try:
     from openai import NOT_GIVEN, AsyncOpenAI, AsyncStream
     from openai.types import ChatModel, chat
     from openai.types.chat import ChatCompletionChunk
-    from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
 except ImportError as _import_error:
     raise ImportError(
         'Please install `openai` to use the OpenAI model, '
@@ -54,6 +51,8 @@ Using this more broad type for the model name instead of the ChatModel definition
 allows this model to be used more easily with other model types (ie, Ollama)
 """

+OpenAISystemPromptRole = Literal['system', 'developer', 'user']
+

 @dataclass(init=False)
 class OpenAIModel(Model):
@@ -66,6 +65,7 @@ class OpenAIModel(Model):

     model_name: OpenAIModelName
     client: AsyncOpenAI = field(repr=False)
+    system_prompt_role: OpenAISystemPromptRole | None = field(default=None)

     def __init__(
         self,
@@ -75,6 +75,7 @@ class OpenAIModel(Model):
         api_key: str | None = None,
         openai_client: AsyncOpenAI | None = None,
         http_client: AsyncHTTPClient | None = None,
+        system_prompt_role: OpenAISystemPromptRole | None = None,
     ):
         """Initialize an OpenAI model.

@@ -90,6 +91,8 @@ class OpenAIModel(Model):
                 [`AsyncOpenAI`](https://github.com/openai/openai-python?tab=readme-ov-file#async-usage)
                 client to use. If provided, `base_url`, `api_key`, and `http_client` must be `None`.
             http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
+            system_prompt_role: The role to use for the system prompt message. If not provided, defaults to `'system'`.
+                In the future, this may be inferred from the model name.
         """
         self.model_name: OpenAIModelName = model_name
         if openai_client is not None:
@@ -101,6 +104,7 @@ class OpenAIModel(Model):
             self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client)
         else:
             self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=cached_async_http_client())
+        self.system_prompt_role = system_prompt_role

     async def agent_model(
         self,
@@ -118,6 +122,7 @@ class OpenAIModel(Model):
             self.model_name,
             allow_text_result,
             tools,
+            self.system_prompt_role,
         )

     def name(self) -> str:
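The headline addition here is `system_prompt_role`, which threads from `OpenAIModel.__init__` through `OpenAIAgentModel` down to message mapping. A minimal usage sketch under the new API (the model name and prompts are illustrative, not taken from this diff):

    from pydantic_ai import Agent
    from pydantic_ai.models.openai import OpenAIModel

    # Send system prompts with the 'developer' role, which newer OpenAI
    # reasoning models expect in place of 'system'.
    model = OpenAIModel('gpt-4o', system_prompt_role='developer')
    agent = Agent(model, system_prompt='Be concise.')

    result = agent.run_sync('What is the capital of France?')
    print(result.data)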
@@ -143,17 +148,18 @@ class OpenAIAgentModel(AgentModel):
     model_name: OpenAIModelName
     allow_text_result: bool
     tools: list[chat.ChatCompletionToolParam]
+    system_prompt_role: OpenAISystemPromptRole | None

     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, result.Usage]:
+    ) -> tuple[ModelResponse, usage.Usage]:
         response = await self._completions_create(messages, False, model_settings)
         return self._process_response(response), _map_usage(response)

     @asynccontextmanager
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> AsyncIterator[EitherStreamedResponse]:
+    ) -> AsyncIterator[StreamedResponse]:
         response = await self._completions_create(messages, True, model_settings)
         async with response:
             yield await self._process_streamed_response(response)
@@ -189,7 +195,7 @@ class OpenAIAgentModel(AgentModel):
             model=self.model_name,
             messages=openai_messages,
             n=1,
-            parallel_tool_calls=True if self.tools else NOT_GIVEN,
+            parallel_tool_calls=model_settings.get('parallel_tool_calls', True if self.tools else NOT_GIVEN),
             tools=self.tools or NOT_GIVEN,
             tool_choice=tool_choice or NOT_GIVEN,
             stream=stream,
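`parallel_tool_calls` is no longer hard-coded to `True` whenever tools are registered; it is now read from `model_settings`, with the old behaviour as the fallback. A sketch of opting out for a single run (assuming `run_sync` accepts `model_settings`, as it does in this release line):

    from pydantic_ai import Agent

    agent = Agent('openai:gpt-4o')

    # Ask the API to call at most one tool per response turn.
    result = agent.run_sync(
        'Check the weather and the time.',
        model_settings={'parallel_tool_calls': False},
    )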
@@ -200,8 +206,7 @@ class OpenAIAgentModel(AgentModel):
             timeout=model_settings.get('timeout', NOT_GIVEN),
         )

-    @staticmethod
-    def _process_response(response: chat.ChatCompletion) -> ModelResponse:
+    def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
         timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
         choice = response.choices[0]
@@ -211,42 +216,25 @@ class OpenAIAgentModel(AgentModel):
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
                 items.append(ToolCallPart.from_raw_args(c.function.name, c.function.arguments, c.id))
-        return ModelResponse(items, timestamp=timestamp)
+        return ModelResponse(items, model_name=self.model_name, timestamp=timestamp)

-    @staticmethod
-    async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> EitherStreamedResponse:
+    async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> OpenAIStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
-        timestamp: datetime | None = None
-        start_usage = Usage()
-        # the first chunk may contain enough information so we iterate until we get either `tool_calls` or `content`
-        while True:
-            try:
-                chunk = await response.__anext__()
-            except StopAsyncIteration as e:
-                raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e
-
-            timestamp = timestamp or datetime.fromtimestamp(chunk.created, tz=timezone.utc)
-            start_usage += _map_usage(chunk)
-
-            if chunk.choices:
-                delta = chunk.choices[0].delta
-
-                if delta.content is not None:
-                    return OpenAIStreamTextResponse(delta.content, response, timestamp, start_usage)
-                elif delta.tool_calls is not None:
-                    return OpenAIStreamStructuredResponse(
-                        response,
-                        {c.index: c for c in delta.tool_calls},
-                        timestamp,
-                        start_usage,
-                    )
-            # else continue until we get either delta.content or delta.tool_calls
+        peekable_response = _utils.PeekableAsyncStream(response)
+        first_chunk = await peekable_response.peek()
+        if isinstance(first_chunk, _utils.Unset):
+            raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
+
+        return OpenAIStreamedResponse(
+            _model_name=self.model_name,
+            _response=peekable_response,
+            _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
+        )

-    @classmethod
-    def _map_message(cls, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
+    def _map_message(self, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
         """Just maps a `pydantic_ai.Message` to a `openai.types.ChatCompletionMessageParam`."""
         if isinstance(message, ModelRequest):
-            yield from cls._map_user_message(message)
+            yield from self._map_user_message(message)
         elif isinstance(message, ModelResponse):
             texts: list[str] = []
             tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
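Instead of consuming chunks in a loop to decide between a text and a structured stream, `_process_streamed_response` now only peeks at the first chunk (for the timestamp) and hands the untouched stream to `OpenAIStreamedResponse`. The diff doesn't show `_utils.PeekableAsyncStream` itself; a rough illustration of what such a wrapper might look like (an assumption, not the library's implementation):

    class Unset:
        """Sentinel type; an instance marks 'no value cached' / 'stream exhausted'."""


    UNSET = Unset()


    class PeekableAsyncStream:
        """Wrap an async iterable so the first item can be inspected without consuming it."""

        def __init__(self, source):
            self._source = source.__aiter__()
            self._peeked: object = UNSET

        async def peek(self) -> object:
            # Cache one item ahead of iteration; return the sentinel if the stream is empty.
            if isinstance(self._peeked, Unset):
                try:
                    self._peeked = await self._source.__anext__()
                except StopAsyncIteration:
                    return UNSET
            return self._peeked

        def __aiter__(self):
            return self

        async def __anext__(self):
            # Serve the cached item first, then fall through to the underlying stream.
            if not isinstance(self._peeked, Unset):
                item, self._peeked = self._peeked, UNSET
                return item
            return await self._source.__anext__()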
@@ -268,11 +256,15 @@ class OpenAIAgentModel(AgentModel):
         else:
             assert_never(message)

-    @classmethod
-    def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
+    def _map_user_message(self, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
         for part in message.parts:
             if isinstance(part, SystemPromptPart):
-                yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
+                if self.system_prompt_role == 'developer':
+                    yield chat.ChatCompletionDeveloperMessageParam(role='developer', content=part.content)
+                elif self.system_prompt_role == 'user':
+                    yield chat.ChatCompletionUserMessageParam(role='user', content=part.content)
+                else:
+                    yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
             elif isinstance(part, UserPromptPart):
                 yield chat.ChatCompletionUserMessageParam(role='user', content=part.content)
             elif isinstance(part, ToolReturnPart):
@@ -295,88 +287,35 @@ class OpenAIAgentModel(AgentModel):


 @dataclass
-class OpenAIStreamTextResponse(StreamTextResponse):
-    """Implementation of `StreamTextResponse` for OpenAI models."""
-
-    _first: str | None
-    _response: AsyncStream[ChatCompletionChunk]
-    _timestamp: datetime
-    _usage: result.Usage
-    _buffer: list[str] = field(default_factory=list, init=False)
-
-    async def __anext__(self) -> None:
-        if self._first is not None:
-            self._buffer.append(self._first)
-            self._first = None
-            return None
-
-        chunk = await self._response.__anext__()
-        self._usage += _map_usage(chunk)
-        try:
-            choice = chunk.choices[0]
-        except IndexError:
-            raise StopAsyncIteration()
-
-        # we don't raise StopAsyncIteration on the last chunk because usage comes after this
-        if choice.finish_reason is None:
-            assert choice.delta.content is not None, f'Expected delta with content, invalid chunk: {chunk!r}'
-        if choice.delta.content is not None:
-            self._buffer.append(choice.delta.content)
-
-    def get(self, *, final: bool = False) -> Iterable[str]:
-        yield from self._buffer
-        self._buffer.clear()
-
-    def usage(self) -> Usage:
-        return self._usage
-
-    def timestamp(self) -> datetime:
-        return self._timestamp
-
-
-@dataclass
-class OpenAIStreamStructuredResponse(StreamStructuredResponse):
-    """Implementation of `StreamStructuredResponse` for OpenAI models."""
+class OpenAIStreamedResponse(StreamedResponse):
+    """Implementation of `StreamedResponse` for OpenAI models."""

-    _response: AsyncStream[ChatCompletionChunk]
-    _delta_tool_calls: dict[int, ChoiceDeltaToolCall]
+    _response: AsyncIterable[ChatCompletionChunk]
     _timestamp: datetime
-    _usage: result.Usage
-
-    async def __anext__(self) -> None:
-        chunk = await self._response.__anext__()
-        self._usage += _map_usage(chunk)
-        try:
-            choice = chunk.choices[0]
-        except IndexError:
-            raise StopAsyncIteration()
-
-        if choice.finish_reason is not None:
-            raise StopAsyncIteration()
-
-        assert choice.delta.content is None, f'Expected tool calls, got content instead, invalid chunk: {chunk!r}'
-
-        for new in choice.delta.tool_calls or []:
-            if current := self._delta_tool_calls.get(new.index):
-                if current.function is None:
-                    current.function = new.function
-                elif new.function is not None:
-                    current.function.name = _utils.add_optional(current.function.name, new.function.name)
-                    current.function.arguments = _utils.add_optional(current.function.arguments, new.function.arguments)
-            else:
-                self._delta_tool_calls[new.index] = new
-
-    def get(self, *, final: bool = False) -> ModelResponse:
-        items: list[ModelResponsePart] = []
-        for c in self._delta_tool_calls.values():
-            if f := c.function:
-                if f.name is not None and f.arguments is not None:
-                    items.append(ToolCallPart.from_raw_args(f.name, f.arguments, c.id))

-        return ModelResponse(items, timestamp=self._timestamp)
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+        async for chunk in self._response:
+            self._usage += _map_usage(chunk)

-    def usage(self) -> Usage:
-        return self._usage
+            try:
+                choice = chunk.choices[0]
+            except IndexError:
+                continue
+
+            # Handle the text part of the response
+            content = choice.delta.content
+            if content is not None:
+                yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content)
+
+            for dtc in choice.delta.tool_calls or []:
+                maybe_event = self._parts_manager.handle_tool_call_delta(
+                    vendor_part_id=dtc.index,
+                    tool_name=dtc.function and dtc.function.name,
+                    args=dtc.function and dtc.function.arguments,
+                    tool_call_id=dtc.id,
+                )
+                if maybe_event is not None:
+                    yield maybe_event

     def timestamp(self) -> datetime:
         return self._timestamp
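The two streaming classes collapse into a single `OpenAIStreamedResponse` that emits part events through `_parts_manager` instead of buffering text or accumulating tool-call deltas itself. From the caller's side the streaming API is unchanged; a small consumption sketch (assuming `Agent.run_stream` and `stream_text`, which exist in this release line):

    import asyncio

    from pydantic_ai import Agent

    agent = Agent('openai:gpt-4o')


    async def main():
        async with agent.run_stream('Tell me a joke.') as result:
            # delta=True yields text fragments as the model produces them
            async for chunk in result.stream_text(delta=True):
                print(chunk, end='', flush=True)


    asyncio.run(main())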
@@ -390,19 +329,19 @@ def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
     )


-def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk) -> result.Usage:
-    usage = response.usage
-    if usage is None:
-        return result.Usage()
+def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk) -> usage.Usage:
+    response_usage = response.usage
+    if response_usage is None:
+        return usage.Usage()
     else:
         details: dict[str, int] = {}
-        if usage.completion_tokens_details is not None:
-            details.update(usage.completion_tokens_details.model_dump(exclude_none=True))
-        if usage.prompt_tokens_details is not None:
-            details.update(usage.prompt_tokens_details.model_dump(exclude_none=True))
-        return result.Usage(
-            request_tokens=usage.prompt_tokens,
-            response_tokens=usage.completion_tokens,
-            total_tokens=usage.total_tokens,
+        if response_usage.completion_tokens_details is not None:
+            details.update(response_usage.completion_tokens_details.model_dump(exclude_none=True))
+        if response_usage.prompt_tokens_details is not None:
+            details.update(response_usage.prompt_tokens_details.model_dump(exclude_none=True))
+        return usage.Usage(
+            request_tokens=response_usage.prompt_tokens,
+            response_tokens=response_usage.completion_tokens,
+            total_tokens=response_usage.total_tokens,
             details=details,
         )
pydantic_ai/models/test.py

@@ -2,21 +2,22 @@ from __future__ import annotations as _annotations

 import re
 import string
-from collections.abc import AsyncIterator, Iterable, Iterator
+from collections.abc import AsyncIterator, Iterable
 from contextlib import asynccontextmanager
-from dataclasses import dataclass, field
+from dataclasses import InitVar, dataclass, field
 from datetime import date, datetime, timedelta
 from typing import Any, Literal

 import pydantic_core
-from typing_extensions import assert_never

 from .. import _utils
 from ..messages import (
+    ArgsJson,
     ModelMessage,
     ModelRequest,
     ModelResponse,
     ModelResponsePart,
+    ModelResponseStreamEvent,
     RetryPromptPart,
     TextPart,
     ToolCallPart,
@@ -27,12 +28,10 @@ from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
     AgentModel,
-    EitherStreamedResponse,
     Model,
-    StreamStructuredResponse,
-    StreamTextResponse,
+    StreamedResponse,
 )
-from .function import _estimate_string_usage, _estimate_usage  # pyright: ignore[reportPrivateUsage]
+from .function import _estimate_string_tokens, _estimate_usage  # pyright: ignore[reportPrivateUsage]


 @dataclass
@@ -130,6 +129,7 @@ class TestAgentModel(AgentModel):
     result: _utils.Either[str | None, Any | None]
     result_tools: list[ToolDefinition]
     seed: int
+    model_name: str = 'test'

     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
@@ -141,25 +141,9 @@ class TestAgentModel(AgentModel):
     @asynccontextmanager
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> AsyncIterator[EitherStreamedResponse]:
-        msg = self._request(messages, model_settings)
-        usage = _estimate_usage(messages)
-
-        # TODO: Rework this once we make StreamTextResponse more general
-        texts: list[str] = []
-        tool_calls: list[ToolCallPart] = []
-        for item in msg.parts:
-            if isinstance(item, TextPart):
-                texts.append(item.content)
-            elif isinstance(item, ToolCallPart):
-                tool_calls.append(item)
-            else:
-                assert_never(item)
-
-        if texts:
-            yield TestStreamTextResponse('\n\n'.join(texts), usage)
-        else:
-            yield TestStreamStructuredResponse(msg, usage)
+    ) -> AsyncIterator[StreamedResponse]:
+        model_response = self._request(messages, model_settings)
+        yield TestStreamedResponse(_model_name=self.model_name, _structured_response=model_response, _messages=messages)

     def gen_tool_args(self, tool_def: ToolDefinition) -> Any:
         return _JsonSchemaTestData(tool_def.parameters_json_schema, self.seed).generate()
@@ -168,7 +152,8 @@ class TestAgentModel(AgentModel):
         # if there are tools, the first thing we want to do is call all of them
         if self.tool_calls and not any(isinstance(m, ModelResponse) for m in messages):
             return ModelResponse(
-                parts=[ToolCallPart.from_raw_args(name, self.gen_tool_args(args)) for name, args in self.tool_calls]
+                parts=[ToolCallPart.from_raw_args(name, self.gen_tool_args(args)) for name, args in self.tool_calls],
+                model_name=self.model_name,
             )

         if messages:
@@ -194,7 +179,7 @@ class TestAgentModel(AgentModel):
                     if tool.name in new_retry_names
                 ]
             )
-            return ModelResponse(parts=retry_parts)
+            return ModelResponse(parts=retry_parts, model_name=self.model_name)

         if response_text := self.result.left:
             if response_text.value is None:
@@ -206,75 +191,60 @@ class TestAgentModel(AgentModel):
                         if isinstance(part, ToolReturnPart):
                             output[part.tool_name] = part.content
                 if output:
-                    return ModelResponse.from_text(pydantic_core.to_json(output).decode())
+                    return ModelResponse(
+                        parts=[TextPart(pydantic_core.to_json(output).decode())], model_name=self.model_name
+                    )
                 else:
-                    return ModelResponse.from_text('success (no tool calls)')
+                    return ModelResponse(parts=[TextPart('success (no tool calls)')], model_name=self.model_name)
             else:
-                return ModelResponse.from_text(response_text.value)
+                return ModelResponse(parts=[TextPart(response_text.value)], model_name=self.model_name)
         else:
             assert self.result_tools, 'No result tools provided'
             custom_result_args = self.result.right
             result_tool = self.result_tools[self.seed % len(self.result_tools)]
             if custom_result_args is not None:
-                return ModelResponse(parts=[ToolCallPart.from_raw_args(result_tool.name, custom_result_args)])
+                return ModelResponse(
+                    parts=[ToolCallPart.from_raw_args(result_tool.name, custom_result_args)], model_name=self.model_name
+                )
             else:
                 response_args = self.gen_tool_args(result_tool)
-                return ModelResponse(parts=[ToolCallPart.from_raw_args(result_tool.name, response_args)])
+                return ModelResponse(
+                    parts=[ToolCallPart.from_raw_args(result_tool.name, response_args)], model_name=self.model_name
+                )


 @dataclass
-class TestStreamTextResponse(StreamTextResponse):
-    """A text response that streams test data."""
-
-    _text: str
-    _usage: Usage
-    _iter: Iterator[str] = field(init=False)
-    _timestamp: datetime = field(default_factory=_utils.now_utc)
-    _buffer: list[str] = field(default_factory=list, init=False)
-
-    def __post_init__(self):
-        *words, last_word = self._text.split(' ')
-        words = [f'{word} ' for word in words]
-        words.append(last_word)
-        if len(words) == 1 and len(self._text) > 2:
-            mid = len(self._text) // 2
-            words = [self._text[:mid], self._text[mid:]]
-        self._iter = iter(words)
-
-    async def __anext__(self) -> None:
-        next_str = _utils.sync_anext(self._iter)
-        response_tokens = _estimate_string_usage(next_str)
-        self._usage += Usage(response_tokens=response_tokens, total_tokens=response_tokens)
-        self._buffer.append(next_str)
-
-    def get(self, *, final: bool = False) -> Iterable[str]:
-        yield from self._buffer
-        self._buffer.clear()
-
-    def usage(self) -> Usage:
-        return self._usage
-
-    def timestamp(self) -> datetime:
-        return self._timestamp
-
-
-@dataclass
-class TestStreamStructuredResponse(StreamStructuredResponse):
+class TestStreamedResponse(StreamedResponse):
     """A structured response that streams test data."""

     _structured_response: ModelResponse
-    _usage: Usage
-    _iter: Iterator[None] = field(default_factory=lambda: iter([None]))
-    _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
-
-    async def __anext__(self) -> None:
-        return _utils.sync_anext(self._iter)
+    _messages: InitVar[Iterable[ModelMessage]]

-    def get(self, *, final: bool = False) -> ModelResponse:
-        return self._structured_response
+    _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)

-    def usage(self) -> Usage:
-        return self._usage
+    def __post_init__(self, _messages: Iterable[ModelMessage]):
+        self._usage = _estimate_usage(_messages)
+
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+        for i, part in enumerate(self._structured_response.parts):
+            if isinstance(part, TextPart):
+                text = part.content
+                *words, last_word = text.split(' ')
+                words = [f'{word} ' for word in words]
+                words.append(last_word)
+                if len(words) == 1 and len(text) > 2:
+                    mid = len(text) // 2
+                    words = [text[:mid], text[mid:]]
+                self._usage += _get_string_usage('')
+                yield self._parts_manager.handle_text_delta(vendor_part_id=i, content='')
+                for word in words:
+                    self._usage += _get_string_usage(word)
+                    yield self._parts_manager.handle_text_delta(vendor_part_id=i, content=word)
+            else:
+                args = part.args.args_json if isinstance(part.args, ArgsJson) else part.args.args_dict
+                yield self._parts_manager.handle_tool_call_part(
+                    vendor_part_id=i, tool_name=part.tool_name, args=args, tool_call_id=part.tool_call_id
+                )

     def timestamp(self) -> datetime:
         return self._timestamp
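`TestStreamedResponse` still splits each text part on spaces, so streamed runs exercise multi-chunk behaviour even against the test model. A hedged example of what that looks like through `TestModel` (the exact chunking shown is indicative):

    import asyncio

    from pydantic_ai import Agent
    from pydantic_ai.models.test import TestModel

    agent = Agent(TestModel(custom_result_text='hello world foo'))


    async def main():
        async with agent.run_stream('anything') as result:
            # Expect roughly: 'hello ', 'world ', 'foo'
            async for text in result.stream_text(delta=True):
                print(repr(text))


    asyncio.run(main())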
@@ -434,3 +404,8 @@ class _JsonSchemaTestData:
             rem //= chars
             s += _chars[self.seed % chars]
         return s
+
+
+def _get_string_usage(text: str) -> Usage:
+    response_tokens = _estimate_string_tokens(text)
+    return Usage(response_tokens=response_tokens, total_tokens=response_tokens)
pydantic_ai/models/vertexai.py

@@ -10,7 +10,7 @@ from httpx import AsyncClient as AsyncHTTPClient
 from .._utils import run_in_executor
 from ..exceptions import UserError
 from ..tools import ToolDefinition
-from . import Model, cached_async_http_client
+from . import Model, cached_async_http_client, check_allow_model_requests
 from .gemini import GeminiAgentModel, GeminiModelName

 try:
@@ -114,6 +114,7 @@ class VertexAIModel(Model):
         allow_text_result: bool,
         result_tools: list[ToolDefinition],
     ) -> GeminiAgentModel:
+        check_allow_model_requests()
         url, auth = await self.ainit()
         return GeminiAgentModel(
             http_client=self.http_client,
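`check_allow_model_requests()` is now called when building the VertexAI agent model, so the global request guard applies to VertexAI as it already did to other providers. A sketch (assuming the module-level `ALLOW_MODEL_REQUESTS` flag that `check_allow_model_requests` consults):

    from pydantic_ai import Agent, models
    from pydantic_ai.models.vertexai import VertexAIModel

    # Block accidental real requests, e.g. in a test suite.
    models.ALLOW_MODEL_REQUESTS = False

    agent = Agent(VertexAIModel('gemini-1.5-flash'))
    try:
        agent.run_sync('hello')  # now fails fast for VertexAI too
    except RuntimeError as exc:
        print(exc)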