pydantic-ai-slim 0.0.18__py3-none-any.whl → 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


pydantic_ai/models/__init__.py (file paths below are inferred from the hunk contents):

```diff
@@ -7,20 +7,22 @@ specific LLM being used.
 from __future__ import annotations as _annotations
 
 from abc import ABC, abstractmethod
-from collections.abc import AsyncIterator, Iterable, Iterator
+from collections.abc import AsyncIterator, Iterator
 from contextlib import asynccontextmanager, contextmanager
+from dataclasses import dataclass, field
 from datetime import datetime
 from functools import cache
-from typing import TYPE_CHECKING, Literal, Union
+from typing import TYPE_CHECKING, Literal
 
 import httpx
 
+from .._parts_manager import ModelResponsePartsManager
 from ..exceptions import UserError
-from ..messages import ModelMessage, ModelResponse
+from ..messages import ModelMessage, ModelResponse, ModelResponseStreamEvent
 from ..settings import ModelSettings
+from ..usage import Usage
 
 if TYPE_CHECKING:
-    from ..result import Usage
     from ..tools import ToolDefinition
 
 
```
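The visible effect of this import reshuffle is that `Usage` now comes from a dedicated `pydantic_ai.usage` module, imported at runtime, instead of being a type-checking-only import from `pydantic_ai.result`. A minimal downstream sketch of the move; whether `pydantic_ai.result` keeps a compatibility re-export is not visible in this diff:

```python
# 0.0.18: Usage lived in pydantic_ai.result
# from pydantic_ai.result import Usage

# 0.0.19: Usage has its own module
from pydantic_ai.usage import Usage

# Field names taken from the usage-mapping code later in this diff.
u = Usage(request_tokens=10, response_tokens=5, total_tokens=15)
```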
```diff
@@ -70,6 +72,7 @@ KnownModelName = Literal[
     'ollama:mistral-nemo',
     'ollama:mixtral',
     'ollama:phi3',
+    'ollama:phi4',
     'ollama:qwq',
     'ollama:qwen',
     'ollama:qwen2',
```
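`'ollama:phi4'` joins the `KnownModelName` literal, so type checkers now accept it wherever a known model name string is expected. A minimal sketch, assuming a running Ollama server with the phi4 model pulled:

```python
from pydantic_ai import Agent

# 'ollama:phi4' now type-checks as a KnownModelName.
agent = Agent('ollama:phi4')

result = agent.run_sync('What is the capital of France?')
print(result.data)
```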
```diff
@@ -129,88 +132,47 @@ class AgentModel(ABC):
     @asynccontextmanager
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> AsyncIterator[EitherStreamedResponse]:
+    ) -> AsyncIterator[StreamedResponse]:
         """Make a request to the model and return a streaming response."""
+        # This method is not required, but you need to implement it if you want to support streamed responses
         raise NotImplementedError(f'Streamed requests not supported by this {self.__class__.__name__}')
         # yield is required to make this a generator for type checking
         # noinspection PyUnreachableCode
         yield  # pragma: no cover
 
 
-class StreamTextResponse(ABC):
-    """Streamed response from an LLM when returning text."""
-
-    def __aiter__(self) -> AsyncIterator[None]:
-        """Stream the response as an async iterable, building up the text as it goes.
-
-        This is an async iterator that yields `None` to avoid doing the work of validating the input and
-        extracting the text field when it will often be thrown away.
-        """
-        return self
-
-    @abstractmethod
-    async def __anext__(self) -> None:
-        """Process the next chunk of the response, see above for why this returns `None`."""
-        raise NotImplementedError()
-
-    @abstractmethod
-    def get(self, *, final: bool = False) -> Iterable[str]:
-        """Returns an iterable of text since the last call to `get()` — e.g. the text delta.
-
-        Args:
-            final: If True, this is the final call, after iteration is complete, the response should be fully validated
-                and all text extracted.
-        """
-        raise NotImplementedError()
-
-    @abstractmethod
-    def usage(self) -> Usage:
-        """Return the usage of the request.
-
-        NOTE: this won't return the full usage until the stream is finished.
-        """
-        raise NotImplementedError()
-
-    @abstractmethod
-    def timestamp(self) -> datetime:
-        """Get the timestamp of the response."""
-        raise NotImplementedError()
-
-
-class StreamStructuredResponse(ABC):
+@dataclass
+class StreamedResponse(ABC):
     """Streamed response from an LLM when calling a tool."""
 
-    def __aiter__(self) -> AsyncIterator[None]:
-        """Stream the response as an async iterable, building up the tool call as it goes.
+    _usage: Usage = field(default_factory=Usage, init=False)
+    _parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False)
+    _event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False)
 
-        This is an async iterator that yields `None` to avoid doing the work of building the final tool call when
-        it will often be thrown away.
-        """
-        return self
+    def __aiter__(self) -> AsyncIterator[ModelResponseStreamEvent]:
+        """Stream the response as an async iterable of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s."""
+        if self._event_iterator is None:
+            self._event_iterator = self._get_event_iterator()
+        return self._event_iterator
 
     @abstractmethod
-    async def __anext__(self) -> None:
-        """Process the next chunk of the response, see above for why this returns `None`."""
-        raise NotImplementedError()
-
-    @abstractmethod
-    def get(self, *, final: bool = False) -> ModelResponse:
-        """Get the `ModelResponse` at this point.
-
-        The `ModelResponse` may or may not be complete, depending on whether the stream is finished.
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+        """Return an async iterator of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s.
 
-        Args:
-            final: If True, this is the final call, after iteration is complete, the response should be fully validated.
+        This method should be implemented by subclasses to translate the vendor-specific stream of events into
+        pydantic_ai-format events.
         """
         raise NotImplementedError()
+        # noinspection PyUnreachableCode
+        yield
 
-    @abstractmethod
-    def usage(self) -> Usage:
-        """Get the usage of the request.
+    def get(self) -> ModelResponse:
+        """Build a [`ModelResponse`][pydantic_ai.messages.ModelResponse] from the data received from the stream so far."""
+        return ModelResponse(parts=self._parts_manager.get_parts(), timestamp=self.timestamp())
 
-        NOTE: this won't return the full usage until the stream is finished.
-        """
-        raise NotImplementedError()
+    def usage(self) -> Usage:
+        """Get the usage of the response so far. This will not be the final usage until the stream is exhausted."""
+        return self._usage
 
     @abstractmethod
     def timestamp(self) -> datetime:
```
```diff
@@ -218,9 +180,6 @@ class StreamStructuredResponse(ABC):
         raise NotImplementedError()
 
 
-EitherStreamedResponse = Union[StreamTextResponse, StreamStructuredResponse]
-
-
 ALLOW_MODEL_REQUESTS = True
 """Whether to allow requests to models.
 
```
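This is the heart of the release: the `StreamTextResponse`/`StreamStructuredResponse` pair and the `EitherStreamedResponse` union collapse into a single `StreamedResponse` ABC. Subclasses implement only `_get_event_iterator()` and `timestamp()`, push vendor deltas through the shared `_parts_manager`, and inherit `__aiter__`, `get()`, and `usage()`. A hypothetical minimal subclass to illustrate the contract (the class name and chunk source are invented; the `handle_text_delta` call mirrors `FunctionStreamedResponse` below):

```python
from collections.abc import AsyncIterator
from dataclasses import dataclass, field
from datetime import datetime, timezone

from pydantic_ai.messages import ModelResponseStreamEvent
from pydantic_ai.models import StreamedResponse


@dataclass
class CannedStreamedResponse(StreamedResponse):
    """Hypothetical StreamedResponse that replays a fixed list of text chunks."""

    _chunks: list[str]
    _timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        for chunk in self._chunks:
            # The inherited parts manager merges deltas into response parts
            # and emits the corresponding stream event for each one.
            yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=chunk)

    def timestamp(self) -> datetime:
        return self._timestamp
```

Iterating the instance yields events, while `get()` and `usage()` can be called at any point for a snapshot `ModelResponse` and the running `Usage`.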
pydantic_ai/models/anthropic.py:

```diff
@@ -8,7 +8,7 @@ from typing import Any, Literal, Union, cast, overload
 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never
 
-from .. import result
+from .. import usage
 from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
     ArgsDict,
```
```diff
@@ -27,8 +27,8 @@ from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
     AgentModel,
-    EitherStreamedResponse,
     Model,
+    StreamedResponse,
     cached_async_http_client,
     check_allow_model_requests,
 )
```
```diff
@@ -158,14 +158,14 @@ class AnthropicAgentModel(AgentModel):
 
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, result.Usage]:
+    ) -> tuple[ModelResponse, usage.Usage]:
         response = await self._messages_create(messages, False, model_settings)
         return self._process_response(response), _map_usage(response)
 
     @asynccontextmanager
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> AsyncIterator[EitherStreamedResponse]:
+    ) -> AsyncIterator[StreamedResponse]:
         response = await self._messages_create(messages, True, model_settings)
         async with response:
             yield await self._process_streamed_response(response)
```
```diff
@@ -216,28 +216,28 @@ class AnthropicAgentModel(AgentModel):
         items: list[ModelResponsePart] = []
         for item in response.content:
             if isinstance(item, TextBlock):
-                items.append(TextPart(item.text))
+                items.append(TextPart(content=item.text))
             else:
                 assert isinstance(item, ToolUseBlock), 'unexpected item type'
                 items.append(
                     ToolCallPart.from_raw_args(
-                        item.name,
-                        cast(dict[str, Any], item.input),
-                        item.id,
+                        tool_name=item.name,
+                        args=cast(dict[str, Any], item.input),
+                        tool_call_id=item.id,
                     )
                 )
 
         return ModelResponse(items)
 
     @staticmethod
-    async def _process_streamed_response(response: AsyncStream[RawMessageStreamEvent]) -> EitherStreamedResponse:
+    async def _process_streamed_response(response: AsyncStream[RawMessageStreamEvent]) -> StreamedResponse:
         """TODO: Process a streamed response, and prepare a streaming response to return."""
         # We don't yet support streamed responses from Anthropic, so we raise an error here for now.
         # Streamed responses will be supported in a future release.
 
         raise RuntimeError('Streamed responses are not yet supported for Anthropic models.')
 
-        # Should be returning some sort of AnthropicStreamTextResponse or AnthropicStreamStructuredResponse
+        # Should be returning some sort of AnthropicStreamTextResponse or AnthropicStreamedResponse
         # depending on the type of chunk we get, but we need to establish how we handle (and when we get) the following:
         # RawMessageStartEvent
         # RawMessageDeltaEvent
```
```diff
@@ -315,30 +315,30 @@ def _map_tool_call(t: ToolCallPart) -> ToolUseBlockParam:
     )
 
 
-def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) -> result.Usage:
+def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) -> usage.Usage:
     if isinstance(message, AnthropicMessage):
-        usage = message.usage
+        response_usage = message.usage
     else:
         if isinstance(message, RawMessageStartEvent):
-            usage = message.message.usage
+            response_usage = message.message.usage
         elif isinstance(message, RawMessageDeltaEvent):
-            usage = message.usage
+            response_usage = message.usage
         else:
             # No usage information provided in:
             # - RawMessageStopEvent
             # - RawContentBlockStartEvent
             # - RawContentBlockDeltaEvent
             # - RawContentBlockStopEvent
-            usage = None
+            response_usage = None
 
-    if usage is None:
-        return result.Usage()
+    if response_usage is None:
+        return usage.Usage()
 
-    request_tokens = getattr(usage, 'input_tokens', None)
+    request_tokens = getattr(response_usage, 'input_tokens', None)
 
-    return result.Usage(
+    return usage.Usage(
         # Usage coming from the RawMessageDeltaEvent doesn't have input token data, hence this getattr
         request_tokens=request_tokens,
-        response_tokens=usage.output_tokens,
-        total_tokens=(request_tokens or 0) + usage.output_tokens,
+        response_tokens=response_usage.output_tokens,
+        total_tokens=(request_tokens or 0) + response_usage.output_tokens,
     )
```
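The rename of the local `usage` to `response_usage` is not cosmetic: once the module is imported via `from .. import usage`, assigning to `usage` anywhere inside `_map_usage` would make it a local for the entire function body and shadow the module. A stripped-down illustration of the hazard (the `message` shape is schematic):

```python
from pydantic_ai import usage  # the module introduced in this release


def map_usage_broken(message):
    # Assignment makes `usage` a local name for the whole function,
    # shadowing the imported module.
    usage = getattr(message, 'usage', None)
    if usage is None:
        return usage.Usage()  # AttributeError: 'NoneType' has no attribute 'Usage'


def map_usage_fixed(message):
    response_usage = getattr(message, 'usage', None)  # distinct name
    if response_usage is None:
        return usage.Usage()  # the module is still reachable
```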
pydantic_ai/models/function.py:

```diff
@@ -7,16 +7,17 @@ from contextlib import asynccontextmanager
 from dataclasses import dataclass, field, replace
 from datetime import datetime
 from itertools import chain
-from typing import Callable, Union, cast
+from typing import Callable, Union
 
 from typing_extensions import TypeAlias, assert_never, overload
 
-from .. import _utils, result
+from .. import _utils, usage
+from .._utils import PeekableAsyncStream
 from ..messages import (
     ModelMessage,
     ModelRequest,
     ModelResponse,
-    ModelResponsePart,
+    ModelResponseStreamEvent,
     RetryPromptPart,
     SystemPromptPart,
     TextPart,
```
```diff
@@ -26,7 +27,7 @@ from ..messages import (
 )
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from . import AgentModel, EitherStreamedResponse, Model, StreamStructuredResponse, StreamTextResponse
+from . import AgentModel, Model, StreamedResponse
 
 
 @dataclass(init=False)
```
```diff
@@ -142,7 +143,7 @@ class FunctionAgentModel(AgentModel):
 
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, result.Usage]:
+    ) -> tuple[ModelResponse, usage.Usage]:
         agent_info = replace(self.agent_info, model_settings=model_settings)
 
         assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests'
```
```diff
@@ -158,90 +159,55 @@ class FunctionAgentModel(AgentModel):
     @asynccontextmanager
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> AsyncIterator[EitherStreamedResponse]:
+    ) -> AsyncIterator[StreamedResponse]:
         assert (
             self.stream_function is not None
         ), 'FunctionModel must receive a `stream_function` to support streamed requests'
-        response_stream = self.stream_function(messages, self.agent_info)
-        try:
-            first = await response_stream.__anext__()
-        except StopAsyncIteration as e:
-            raise ValueError('Stream function must return at least one item') from e
-
-        if isinstance(first, str):
-            text_stream = cast(AsyncIterator[str], response_stream)
-            yield FunctionStreamTextResponse(first, text_stream)
-        else:
-            structured_stream = cast(AsyncIterator[DeltaToolCalls], response_stream)
-            yield FunctionStreamStructuredResponse(first, structured_stream)
-
-
-@dataclass
-class FunctionStreamTextResponse(StreamTextResponse):
-    """Implementation of `StreamTextResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel]."""
-
-    _next: str | None
-    _iter: AsyncIterator[str]
-    _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
-    _buffer: list[str] = field(default_factory=list, init=False)
-
-    async def __anext__(self) -> None:
-        if self._next is not None:
-            self._buffer.append(self._next)
-            self._next = None
-        else:
-            self._buffer.append(await self._iter.__anext__())
-
-    def get(self, *, final: bool = False) -> Iterable[str]:
-        yield from self._buffer
-        self._buffer.clear()
+        response_stream = PeekableAsyncStream(self.stream_function(messages, self.agent_info))
 
-    def usage(self) -> result.Usage:
-        return result.Usage()
+        first = await response_stream.peek()
+        if isinstance(first, _utils.Unset):
+            raise ValueError('Stream function must return at least one item')
 
-    def timestamp(self) -> datetime:
-        return self._timestamp
+        yield FunctionStreamedResponse(response_stream)
 
 
 @dataclass
-class FunctionStreamStructuredResponse(StreamStructuredResponse):
-    """Implementation of `StreamStructuredResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel]."""
+class FunctionStreamedResponse(StreamedResponse):
+    """Implementation of `StreamedResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel]."""
 
-    _next: DeltaToolCalls | None
-    _iter: AsyncIterator[DeltaToolCalls]
-    _delta_tool_calls: dict[int, DeltaToolCall] = field(default_factory=dict)
+    _iter: AsyncIterator[str | DeltaToolCalls]
     _timestamp: datetime = field(default_factory=_utils.now_utc)
 
-    async def __anext__(self) -> None:
-        if self._next is not None:
-            tool_call = self._next
-            self._next = None
-        else:
-            tool_call = await self._iter.__anext__()
+    def __post_init__(self):
+        self._usage += _estimate_usage([])
 
-        for key, new in tool_call.items():
-            if current := self._delta_tool_calls.get(key):
-                current.name = _utils.add_optional(current.name, new.name)
-                current.json_args = _utils.add_optional(current.json_args, new.json_args)
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+        async for item in self._iter:
+            if isinstance(item, str):
+                response_tokens = _estimate_string_tokens(item)
+                self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens)
+                yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=item)
             else:
-                self._delta_tool_calls[key] = new
-
-    def get(self, *, final: bool = False) -> ModelResponse:
-        calls: list[ModelResponsePart] = []
-        for c in self._delta_tool_calls.values():
-            if c.name is not None and c.json_args is not None:
-                calls.append(ToolCallPart.from_raw_args(c.name, c.json_args))
-
-        return ModelResponse(calls, timestamp=self._timestamp)
-
-    def usage(self) -> result.Usage:
-        return _estimate_usage([self.get()])
+                delta_tool_calls = item
+                for dtc_index, delta_tool_call in delta_tool_calls.items():
+                    if delta_tool_call.json_args:
+                        response_tokens = _estimate_string_tokens(delta_tool_call.json_args)
+                        self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens)
+                    maybe_event = self._parts_manager.handle_tool_call_delta(
+                        vendor_part_id=dtc_index,
+                        tool_name=delta_tool_call.name,
+                        args=delta_tool_call.json_args,
+                        tool_call_id=None,
+                    )
+                    if maybe_event is not None:
+                        yield maybe_event
 
     def timestamp(self) -> datetime:
         return self._timestamp
 
 
-def _estimate_usage(messages: Iterable[ModelMessage]) -> result.Usage:
+def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.Usage:
     """Very rough guesstimate of the token usage associated with a series of messages.
 
     This is designed to be used solely to give plausible numbers for testing!
 
```
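With `FunctionStreamedResponse` unified, a single `stream_function` may now yield a mix of text chunks and `DeltaToolCalls`, and consumers see one uniform event stream. A sketch of driving it end to end; the agent-level streaming API (`run_stream`, `stream_text`) is assumed from this release's public surface, and the prompt and chunking are invented:

```python
import asyncio
from collections.abc import AsyncIterator

from pydantic_ai import Agent
from pydantic_ai.messages import ModelMessage
from pydantic_ai.models.function import AgentInfo, FunctionModel


async def stream_model(messages: list[ModelMessage], info: AgentInfo) -> AsyncIterator[str]:
    # Each yielded string becomes a text delta, handled by
    # FunctionStreamedResponse via handle_text_delta above.
    for chunk in ('stream', 'ed ', 'response'):
        yield chunk


agent = Agent(FunctionModel(stream_function=stream_model))


async def main() -> None:
    async with agent.run_stream('demo prompt') as result:
        async for text in result.stream_text():
            print(text)


asyncio.run(main())
```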
```diff
@@ -253,28 +219,30 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> result.Usage:
         if isinstance(message, ModelRequest):
             for part in message.parts:
                 if isinstance(part, (SystemPromptPart, UserPromptPart)):
-                    request_tokens += _estimate_string_usage(part.content)
+                    request_tokens += _estimate_string_tokens(part.content)
                 elif isinstance(part, ToolReturnPart):
-                    request_tokens += _estimate_string_usage(part.model_response_str())
+                    request_tokens += _estimate_string_tokens(part.model_response_str())
                 elif isinstance(part, RetryPromptPart):
-                    request_tokens += _estimate_string_usage(part.model_response())
+                    request_tokens += _estimate_string_tokens(part.model_response())
                 else:
                     assert_never(part)
         elif isinstance(message, ModelResponse):
             for part in message.parts:
                 if isinstance(part, TextPart):
-                    response_tokens += _estimate_string_usage(part.content)
+                    response_tokens += _estimate_string_tokens(part.content)
                 elif isinstance(part, ToolCallPart):
                     call = part
-                    response_tokens += 1 + _estimate_string_usage(call.args_as_json_str())
+                    response_tokens += 1 + _estimate_string_tokens(call.args_as_json_str())
                 else:
                     assert_never(part)
         else:
             assert_never(message)
-    return result.Usage(
+    return usage.Usage(
         request_tokens=request_tokens, response_tokens=response_tokens, total_tokens=request_tokens + response_tokens
     )
 
 
-def _estimate_string_usage(content: str) -> int:
-    return len(re.split(r'[\s",.:]+', content))
+def _estimate_string_tokens(content: str) -> int:
+    if not content:
+        return 0
+    return len(re.split(r'[\s",.:]+', content.strip()))
```
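Beyond the rename, `_estimate_string_tokens` now guards empty input and strips before splitting, which removes the phantom tokens created by `re.split` returning empty strings at the boundaries. The new behaviour can be checked directly with the stdlib:

```python
import re


def _estimate_string_tokens(content: str) -> int:
    if not content:
        return 0
    return len(re.split(r'[\s",.:]+', content.strip()))


assert _estimate_string_tokens('') == 0             # the old version counted 1
assert _estimate_string_tokens('hello world') == 2
# The old version counted 4 here: re.split yields '' at both ends
# when the string starts or ends with a separator character.
assert _estimate_string_tokens('  hello world  ') == 2
```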