pydantic-ai-slim 0.0.17__py3-none-any.whl → 0.0.19__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.


@@ -1,6 +1,6 @@
  from __future__ import annotations as _annotations
 
- from collections.abc import AsyncIterator, Iterable
+ from collections.abc import AsyncIterable, AsyncIterator, Iterable
  from contextlib import asynccontextmanager
  from dataclasses import dataclass, field
  from datetime import datetime, timezone
@@ -10,13 +10,14 @@ from typing import Literal, Union, overload
  from httpx import AsyncClient as AsyncHTTPClient
  from typing_extensions import assert_never
 
- from .. import UnexpectedModelBehavior, _utils, result
+ from .. import UnexpectedModelBehavior, _utils, usage
  from .._utils import guard_tool_call_id as _guard_tool_call_id
  from ..messages import (
      ModelMessage,
      ModelRequest,
      ModelResponse,
      ModelResponsePart,
+     ModelResponseStreamEvent,
      RetryPromptPart,
      SystemPromptPart,
      TextPart,
@@ -24,15 +25,12 @@ from ..messages import (
      ToolReturnPart,
      UserPromptPart,
  )
- from ..result import Usage
  from ..settings import ModelSettings
  from ..tools import ToolDefinition
  from . import (
      AgentModel,
-     EitherStreamedResponse,
      Model,
-     StreamStructuredResponse,
-     StreamTextResponse,
+     StreamedResponse,
      cached_async_http_client,
      check_allow_model_requests,
  )
@@ -41,7 +39,6 @@ try:
      from openai import NOT_GIVEN, AsyncOpenAI, AsyncStream
      from openai.types import ChatModel, chat
      from openai.types.chat import ChatCompletionChunk
-     from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
  except ImportError as _import_error:
      raise ImportError(
          'Please install `openai` to use the OpenAI model, '
@@ -146,14 +143,14 @@ class OpenAIAgentModel(AgentModel):
 
      async def request(
          self, messages: list[ModelMessage], model_settings: ModelSettings | None
-     ) -> tuple[ModelResponse, result.Usage]:
+     ) -> tuple[ModelResponse, usage.Usage]:
          response = await self._completions_create(messages, False, model_settings)
          return self._process_response(response), _map_usage(response)
 
      @asynccontextmanager
      async def request_stream(
          self, messages: list[ModelMessage], model_settings: ModelSettings | None
-     ) -> AsyncIterator[EitherStreamedResponse]:
+     ) -> AsyncIterator[StreamedResponse]:
          response = await self._completions_create(messages, True, model_settings)
          async with response:
              yield await self._process_streamed_response(response)
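With `EitherStreamedResponse` gone, `request_stream` now always yields a single `StreamedResponse`, whatever mix of text and tool calls the model produces. A minimal consumption sketch — the `consume` helper is invented for illustration, and it assumes `StreamedResponse` is async-iterable over `ModelResponseStreamEvent`s, as the `_get_event_iterator` method later in this diff suggests:

```python
# Hypothetical caller -- `consume` and its arguments are not part of the diff.
async def consume(agent_model, messages, model_settings=None):
    # request_stream is an async context manager yielding one StreamedResponse.
    async with agent_model.request_stream(messages, model_settings) as streamed:
        async for event in streamed:  # assumed: iterating yields stream events
            print(event)
```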
@@ -214,33 +211,14 @@ class OpenAIAgentModel(AgentModel):
          return ModelResponse(items, timestamp=timestamp)
 
      @staticmethod
-     async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> EitherStreamedResponse:
+     async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> OpenAIStreamedResponse:
          """Process a streamed response, and prepare a streaming response to return."""
-         timestamp: datetime | None = None
-         start_usage = Usage()
-         # the first chunk may contain enough information so we iterate until we get either `tool_calls` or `content`
-         while True:
-             try:
-                 chunk = await response.__anext__()
-             except StopAsyncIteration as e:
-                 raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e
-
-             timestamp = timestamp or datetime.fromtimestamp(chunk.created, tz=timezone.utc)
-             start_usage += _map_usage(chunk)
-
-             if chunk.choices:
-                 delta = chunk.choices[0].delta
-
-                 if delta.content is not None:
-                     return OpenAIStreamTextResponse(delta.content, response, timestamp, start_usage)
-                 elif delta.tool_calls is not None:
-                     return OpenAIStreamStructuredResponse(
-                         response,
-                         {c.index: c for c in delta.tool_calls},
-                         timestamp,
-                         start_usage,
-                     )
-             # else continue until we get either delta.content or delta.tool_calls
+         peekable_response = _utils.PeekableAsyncStream(response)
+         first_chunk = await peekable_response.peek()
+         if isinstance(first_chunk, _utils.Unset):
+             raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
+
+         return OpenAIStreamedResponse(peekable_response, datetime.fromtimestamp(first_chunk.created, tz=timezone.utc))
 
      @classmethod
      def _map_message(cls, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
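Instead of manually looping over the first chunks to decide between a text and a structured response, the new code peeks at the stream: one chunk is read to validate the response and grab its timestamp, without being consumed from the iterator. `_utils.PeekableAsyncStream` itself is not shown in this diff; the sketch below is an assumed, simplified reimplementation of that one-item-lookahead idea, not the library's actual code:

```python
from __future__ import annotations

from collections.abc import AsyncIterable, AsyncIterator
from typing import Generic, TypeVar

T = TypeVar('T')


class Unset:
    """Sentinel returned by peek() when the stream is already exhausted."""


class PeekableAsyncStream(Generic[T]):
    """Wraps an async iterable, buffering at most one item for lookahead."""

    def __init__(self, stream: AsyncIterable[T]):
        self._it = stream.__aiter__()
        self._peeked: T | Unset = Unset()

    async def peek(self) -> T | Unset:
        # Read one item ahead and cache it for the next __anext__ call.
        if isinstance(self._peeked, Unset):
            try:
                self._peeked = await self._it.__anext__()
            except StopAsyncIteration:
                return Unset()
        return self._peeked

    def __aiter__(self) -> AsyncIterator[T]:
        return self

    async def __anext__(self) -> T:
        if not isinstance(self._peeked, Unset):
            item, self._peeked = self._peeked, Unset()
            return item
        return await self._it.__anext__()
```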
@@ -295,88 +273,35 @@ class OpenAIAgentModel(AgentModel):
 
 
  @dataclass
- class OpenAIStreamTextResponse(StreamTextResponse):
-     """Implementation of `StreamTextResponse` for OpenAI models."""
-
-     _first: str | None
-     _response: AsyncStream[ChatCompletionChunk]
-     _timestamp: datetime
-     _usage: result.Usage
-     _buffer: list[str] = field(default_factory=list, init=False)
-
-     async def __anext__(self) -> None:
-         if self._first is not None:
-             self._buffer.append(self._first)
-             self._first = None
-             return None
-
-         chunk = await self._response.__anext__()
-         self._usage += _map_usage(chunk)
-         try:
-             choice = chunk.choices[0]
-         except IndexError:
-             raise StopAsyncIteration()
-
-         # we don't raise StopAsyncIteration on the last chunk because usage comes after this
-         if choice.finish_reason is None:
-             assert choice.delta.content is not None, f'Expected delta with content, invalid chunk: {chunk!r}'
-         if choice.delta.content is not None:
-             self._buffer.append(choice.delta.content)
-
-     def get(self, *, final: bool = False) -> Iterable[str]:
-         yield from self._buffer
-         self._buffer.clear()
-
-     def usage(self) -> Usage:
-         return self._usage
-
-     def timestamp(self) -> datetime:
-         return self._timestamp
-
-
- @dataclass
- class OpenAIStreamStructuredResponse(StreamStructuredResponse):
-     """Implementation of `StreamStructuredResponse` for OpenAI models."""
+ class OpenAIStreamedResponse(StreamedResponse):
+     """Implementation of `StreamedResponse` for OpenAI models."""
 
-     _response: AsyncStream[ChatCompletionChunk]
-     _delta_tool_calls: dict[int, ChoiceDeltaToolCall]
+     _response: AsyncIterable[ChatCompletionChunk]
      _timestamp: datetime
-     _usage: result.Usage
-
-     async def __anext__(self) -> None:
-         chunk = await self._response.__anext__()
-         self._usage += _map_usage(chunk)
-         try:
-             choice = chunk.choices[0]
-         except IndexError:
-             raise StopAsyncIteration()
-
-         if choice.finish_reason is not None:
-             raise StopAsyncIteration()
-
-         assert choice.delta.content is None, f'Expected tool calls, got content instead, invalid chunk: {chunk!r}'
-
-         for new in choice.delta.tool_calls or []:
-             if current := self._delta_tool_calls.get(new.index):
-                 if current.function is None:
-                     current.function = new.function
-                 elif new.function is not None:
-                     current.function.name = _utils.add_optional(current.function.name, new.function.name)
-                     current.function.arguments = _utils.add_optional(current.function.arguments, new.function.arguments)
-             else:
-                 self._delta_tool_calls[new.index] = new
 
-     def get(self, *, final: bool = False) -> ModelResponse:
-         items: list[ModelResponsePart] = []
-         for c in self._delta_tool_calls.values():
-             if f := c.function:
-                 if f.name is not None and f.arguments is not None:
-                     items.append(ToolCallPart.from_raw_args(f.name, f.arguments, c.id))
-
-         return ModelResponse(items, timestamp=self._timestamp)
+     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+         async for chunk in self._response:
+             self._usage += _map_usage(chunk)
 
-     def usage(self) -> Usage:
-         return self._usage
+             try:
+                 choice = chunk.choices[0]
+             except IndexError:
+                 continue
+
+             # Handle the text part of the response
+             content = choice.delta.content
+             if content is not None:
+                 yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content)
+
+             for dtc in choice.delta.tool_calls or []:
+                 maybe_event = self._parts_manager.handle_tool_call_delta(
+                     vendor_part_id=dtc.index,
+                     tool_name=dtc.function and dtc.function.name,
+                     args=dtc.function and dtc.function.arguments,
+                     tool_call_id=dtc.id,
+                 )
+                 if maybe_event is not None:
+                     yield maybe_event
 
      def timestamp(self) -> datetime:
          return self._timestamp
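All streamed output now funnels through `self._parts_manager`, which turns raw deltas into `ModelResponseStreamEvent`s; tool-call fragments are keyed by the vendor's `index` and merged incrementally rather than patched together by hand as in the deleted classes above. A toy illustration of that accumulation semantics (all names here are invented; the real manager also distinguishes part-start from part-delta events):

```python
# Each chunk carries a partial tool name and/or argument fragment per index;
# fragments concatenate in arrival order.
calls: dict[int, dict[str, str]] = {}


def apply_delta(index: int, tool_name: str | None, args: str | None) -> None:
    call = calls.setdefault(index, {'name': '', 'args': ''})
    if tool_name:
        call['name'] += tool_name
    if args:
        call['args'] += args


apply_delta(0, 'get_weather', '{"city"')
apply_delta(0, None, ': "Paris"}')
assert calls[0] == {'name': 'get_weather', 'args': '{"city": "Paris"}'}
```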
@@ -390,19 +315,19 @@ def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
      )
 
 
- def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk) -> result.Usage:
-     usage = response.usage
-     if usage is None:
-         return result.Usage()
+ def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk) -> usage.Usage:
+     response_usage = response.usage
+     if response_usage is None:
+         return usage.Usage()
      else:
          details: dict[str, int] = {}
-         if usage.completion_tokens_details is not None:
-             details.update(usage.completion_tokens_details.model_dump(exclude_none=True))
-         if usage.prompt_tokens_details is not None:
-             details.update(usage.prompt_tokens_details.model_dump(exclude_none=True))
-         return result.Usage(
-             request_tokens=usage.prompt_tokens,
-             response_tokens=usage.completion_tokens,
-             total_tokens=usage.total_tokens,
+         if response_usage.completion_tokens_details is not None:
+             details.update(response_usage.completion_tokens_details.model_dump(exclude_none=True))
+         if response_usage.prompt_tokens_details is not None:
+             details.update(response_usage.prompt_tokens_details.model_dump(exclude_none=True))
+         return usage.Usage(
+             request_tokens=response_usage.prompt_tokens,
+             response_tokens=response_usage.completion_tokens,
+             total_tokens=response_usage.total_tokens,
              details=details,
          )
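Renaming the local variable from `usage` to `response_usage` avoids shadowing the newly imported `usage` module. A rough worked example of the mapping, using a stand-in object rather than the real `openai` response types:

```python
from types import SimpleNamespace

# Stand-in for an openai response carrying a usage payload (not the real type).
fake_response = SimpleNamespace(
    usage=SimpleNamespace(
        prompt_tokens=12,
        completion_tokens=5,
        total_tokens=17,
        completion_tokens_details=None,
        prompt_tokens_details=None,
    )
)
# _map_usage(fake_response) would yield:
# usage.Usage(request_tokens=12, response_tokens=5, total_tokens=17, details={})
```

The remaining hunks come from other modules in the wheel: first the test model, then (at the very end) the Vertex AI model.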
@@ -2,21 +2,22 @@ from __future__ import annotations as _annotations
 
  import re
  import string
- from collections.abc import AsyncIterator, Iterable, Iterator
+ from collections.abc import AsyncIterator, Iterable
  from contextlib import asynccontextmanager
- from dataclasses import dataclass, field
+ from dataclasses import InitVar, dataclass, field
  from datetime import date, datetime, timedelta
  from typing import Any, Literal
 
  import pydantic_core
- from typing_extensions import assert_never
 
  from .. import _utils
  from ..messages import (
+     ArgsJson,
      ModelMessage,
      ModelRequest,
      ModelResponse,
      ModelResponsePart,
+     ModelResponseStreamEvent,
      RetryPromptPart,
      TextPart,
      ToolCallPart,
@@ -27,12 +28,10 @@ from ..settings import ModelSettings
  from ..tools import ToolDefinition
  from . import (
      AgentModel,
-     EitherStreamedResponse,
      Model,
-     StreamStructuredResponse,
-     StreamTextResponse,
+     StreamedResponse,
  )
- from .function import _estimate_string_usage, _estimate_usage  # pyright: ignore[reportPrivateUsage]
+ from .function import _estimate_string_tokens, _estimate_usage  # pyright: ignore[reportPrivateUsage]
 
 
  @dataclass
@@ -141,25 +140,9 @@ class TestAgentModel(AgentModel):
      @asynccontextmanager
      async def request_stream(
          self, messages: list[ModelMessage], model_settings: ModelSettings | None
-     ) -> AsyncIterator[EitherStreamedResponse]:
-         msg = self._request(messages, model_settings)
-         usage = _estimate_usage(messages)
-
-         # TODO: Rework this once we make StreamTextResponse more general
-         texts: list[str] = []
-         tool_calls: list[ToolCallPart] = []
-         for item in msg.parts:
-             if isinstance(item, TextPart):
-                 texts.append(item.content)
-             elif isinstance(item, ToolCallPart):
-                 tool_calls.append(item)
-             else:
-                 assert_never(item)
-
-         if texts:
-             yield TestStreamTextResponse('\n\n'.join(texts), usage)
-         else:
-             yield TestStreamStructuredResponse(msg, usage)
+     ) -> AsyncIterator[StreamedResponse]:
+         model_response = self._request(messages, model_settings)
+         yield TestStreamedResponse(model_response, messages)
 
      def gen_tool_args(self, tool_def: ToolDefinition) -> Any:
          return _JsonSchemaTestData(tool_def.parameters_json_schema, self.seed).generate()
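`TestStreamedResponse(model_response, messages)` passes `messages` as a dataclass `InitVar`: the class, defined in the next hunk, consumes it once in `__post_init__` to pre-compute usage and never stores it as a field. A self-contained illustration of that pattern (the `Demo` class is invented):

```python
from dataclasses import InitVar, dataclass, field


@dataclass
class Demo:
    payload: str
    source: InitVar[list[str]]  # accepted by __init__, forwarded to __post_init__
    size: int = field(init=False)  # derived field, not an __init__ parameter

    def __post_init__(self, source: list[str]) -> None:
        self.size = len(source)  # consume the InitVar; it is never stored


d = Demo('x', ['a', 'b'])
assert d.size == 2 and not hasattr(d, 'source')
```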
@@ -223,58 +206,37 @@ class TestAgentModel(AgentModel):
 
 
  @dataclass
- class TestStreamTextResponse(StreamTextResponse):
-     """A text response that streams test data."""
-
-     _text: str
-     _usage: Usage
-     _iter: Iterator[str] = field(init=False)
-     _timestamp: datetime = field(default_factory=_utils.now_utc)
-     _buffer: list[str] = field(default_factory=list, init=False)
-
-     def __post_init__(self):
-         *words, last_word = self._text.split(' ')
-         words = [f'{word} ' for word in words]
-         words.append(last_word)
-         if len(words) == 1 and len(self._text) > 2:
-             mid = len(self._text) // 2
-             words = [self._text[:mid], self._text[mid:]]
-         self._iter = iter(words)
-
-     async def __anext__(self) -> None:
-         next_str = _utils.sync_anext(self._iter)
-         response_tokens = _estimate_string_usage(next_str)
-         self._usage += Usage(response_tokens=response_tokens, total_tokens=response_tokens)
-         self._buffer.append(next_str)
-
-     def get(self, *, final: bool = False) -> Iterable[str]:
-         yield from self._buffer
-         self._buffer.clear()
-
-     def usage(self) -> Usage:
-         return self._usage
-
-     def timestamp(self) -> datetime:
-         return self._timestamp
-
-
- @dataclass
- class TestStreamStructuredResponse(StreamStructuredResponse):
+ class TestStreamedResponse(StreamedResponse):
      """A structured response that streams test data."""
 
      _structured_response: ModelResponse
-     _usage: Usage
-     _iter: Iterator[None] = field(default_factory=lambda: iter([None]))
-     _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
-
-     async def __anext__(self) -> None:
-         return _utils.sync_anext(self._iter)
+     _messages: InitVar[Iterable[ModelMessage]]
 
-     def get(self, *, final: bool = False) -> ModelResponse:
-         return self._structured_response
+     _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
 
-     def usage(self) -> Usage:
-         return self._usage
+     def __post_init__(self, _messages: Iterable[ModelMessage]):
+         self._usage = _estimate_usage(_messages)
+
+     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+         for i, part in enumerate(self._structured_response.parts):
+             if isinstance(part, TextPart):
+                 text = part.content
+                 *words, last_word = text.split(' ')
+                 words = [f'{word} ' for word in words]
+                 words.append(last_word)
+                 if len(words) == 1 and len(text) > 2:
+                     mid = len(text) // 2
+                     words = [text[:mid], text[mid:]]
+                 self._usage += _get_string_usage('')
+                 yield self._parts_manager.handle_text_delta(vendor_part_id=i, content='')
+                 for word in words:
+                     self._usage += _get_string_usage(word)
+                     yield self._parts_manager.handle_text_delta(vendor_part_id=i, content=word)
+             else:
+                 args = part.args.args_json if isinstance(part.args, ArgsJson) else part.args.args_dict
+                 yield self._parts_manager.handle_tool_call_part(
+                     vendor_part_id=i, tool_name=part.tool_name, args=args, tool_call_id=part.tool_call_id
+                 )
 
      def timestamp(self) -> datetime:
          return self._timestamp
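The word-by-word chunking that used to live in `TestStreamTextResponse.__post_init__` survives inside `_get_event_iterator`. Extracted as a standalone function for clarity (this helper does not exist in the package):

```python
def split_words(text: str) -> list[str]:
    # Stream roughly word by word, keeping the separating spaces attached.
    *words, last_word = text.split(' ')
    words = [f'{word} ' for word in words]
    words.append(last_word)
    # Split a lone word in half so even one-word text (unless very short)
    # still streams as more than one delta.
    if len(words) == 1 and len(text) > 2:
        mid = len(text) // 2
        words = [text[:mid], text[mid:]]
    return words


assert split_words('hello world') == ['hello ', 'world']
assert split_words('hello') == ['he', 'llo']
```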
@@ -434,3 +396,8 @@ class _JsonSchemaTestData:
              rem //= chars
              s += _chars[self.seed % chars]
          return s
+
+
+ def _get_string_usage(text: str) -> Usage:
+     response_tokens = _estimate_string_tokens(text)
+     return Usage(response_tokens=response_tokens, total_tokens=response_tokens)
@@ -164,7 +164,7 @@ class VertexAIModel(Model):
          return url, auth
 
      def name(self) -> str:
-         return f'vertexai:{self.model_name}'
+         return f'google-vertex:{self.model_name}'
 
 
  # pyright: reportUnknownMemberType=false