pydantic-ai-slim 0.0.23__py3-none-any.whl → 0.0.25__py3-none-any.whl

This diff shows the content changes between package versions publicly released to one of the supported registries, as they appear in those registries. It is provided for informational purposes only.


pydantic_ai/messages.py CHANGED
@@ -1,5 +1,6 @@
  from __future__ import annotations as _annotations
 
+ import uuid
  from dataclasses import dataclass, field, replace
  from datetime import datetime
  from typing import Annotated, Any, Literal, Union, cast, overload
@@ -445,3 +446,33 @@ class PartDeltaEvent:
 
  ModelResponseStreamEvent = Annotated[Union[PartStartEvent, PartDeltaEvent], pydantic.Discriminator('event_kind')]
  """An event in the model response stream, either starting a new part or applying a delta to an existing one."""
+
+
+ @dataclass
+ class FunctionToolCallEvent:
+     """An event indicating the start to a call to a function tool."""
+
+     part: ToolCallPart
+     """The (function) tool call to make."""
+     call_id: str = field(init=False)
+     """An ID used for matching details about the call to its result. If present, defaults to the part's tool_call_id."""
+     event_kind: Literal['function_tool_call'] = 'function_tool_call'
+     """Event type identifier, used as a discriminator."""
+
+     def __post_init__(self):
+         self.call_id = self.part.tool_call_id or str(uuid.uuid4())
+
+
+ @dataclass
+ class FunctionToolResultEvent:
+     """An event indicating the result of a function tool call."""
+
+     result: ToolReturnPart | RetryPromptPart
+     """The result of the call to the function tool."""
+     call_id: str
+     """An ID used to match the result to its original call."""
+     event_kind: Literal['function_tool_result'] = 'function_tool_result'
+     """Event type identifier, used as a discriminator."""
+
+
+ HandleResponseEvent = Annotated[Union[FunctionToolCallEvent, FunctionToolResultEvent], pydantic.Discriminator('kind')]
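
The new `FunctionToolCallEvent` / `FunctionToolResultEvent` pair lets a consumer of the event stream match each tool call to its eventual result via `call_id`. A minimal sketch of that matching, assuming `events` is a list of such events collected from a run:

    from pydantic_ai.messages import FunctionToolCallEvent, FunctionToolResultEvent, HandleResponseEvent

    def pair_calls_with_results(events: list[HandleResponseEvent]) -> None:
        # __post_init__ guarantees call_id is set, falling back to a fresh
        # uuid4 when the part carries no tool_call_id.
        pending: dict[str, FunctionToolCallEvent] = {}
        for event in events:
            if isinstance(event, FunctionToolCallEvent):
                pending[event.call_id] = event
            elif isinstance(event, FunctionToolResultEvent):
                call = pending.pop(event.call_id)
                print(f'{call.part.tool_name} -> {event.result.content!r}')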

pydantic_ai/models/__init__.py CHANGED
@@ -54,6 +54,8 @@ KnownModelName = Literal[
      'google-gla:gemini-2.0-flash-exp',
      'google-gla:gemini-2.0-flash-thinking-exp-01-21',
      'google-gla:gemini-exp-1206',
+     'google-gla:gemini-2.0-flash',
+     'google-gla:gemini-2.0-flash-lite-preview-02-05',
      'google-vertex:gemini-1.0-pro',
      'google-vertex:gemini-1.5-flash',
      'google-vertex:gemini-1.5-flash-8b',
@@ -61,6 +63,8 @@ KnownModelName = Literal[
      'google-vertex:gemini-2.0-flash-exp',
      'google-vertex:gemini-2.0-flash-thinking-exp-01-21',
      'google-vertex:gemini-exp-1206',
+     'google-vertex:gemini-2.0-flash',
+     'google-vertex:gemini-2.0-flash-lite-preview-02-05',
      'gpt-3.5-turbo',
      'gpt-3.5-turbo-0125',
      'gpt-3.5-turbo-0301',
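
The new Gemini entries extend `KnownModelName`, so they can be passed anywhere a model name string is accepted; for example (the prefix selects the API: `google-gla:` for the Generative Language API, `google-vertex:` for Vertex AI):

    from pydantic_ai import Agent

    agent = Agent('google-gla:gemini-2.0-flash')
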
@@ -173,9 +177,6 @@ class ModelRequestParameters:
  class Model(ABC):
      """Abstract class for a model."""
 
-     _model_name: str
-     _system: str | None
-
      @abstractmethod
      async def request(
          self,
@@ -201,24 +202,25 @@
          yield # pragma: no cover
 
      @property
+     @abstractmethod
      def model_name(self) -> str:
          """The model name."""
-         return self._model_name
+         raise NotImplementedError()
 
      @property
+     @abstractmethod
      def system(self) -> str | None:
          """The system / model provider, ex: openai."""
-         return self._system
+         raise NotImplementedError()
 
 
  @dataclass
  class StreamedResponse(ABC):
      """Streamed response from an LLM when calling a tool."""
 
-     _model_name: str
-     _usage: Usage = field(default_factory=Usage, init=False)
      _parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False)
      _event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False)
+     _usage: Usage = field(default_factory=Usage, init=False)
 
      def __aiter__(self) -> AsyncIterator[ModelResponseStreamEvent]:
          """Stream the response as an async iterable of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s."""
@@ -232,6 +234,8 @@ class StreamedResponse(ABC):
 
          This method should be implemented by subclasses to translate the vendor-specific stream of events into
          pydantic_ai-format events.
+
+         It should use the `_parts_manager` to handle deltas, and should update the `_usage` attributes as it goes.
          """
          raise NotImplementedError()
          # noinspection PyUnreachableCode
@@ -240,17 +244,20 @@
      def get(self) -> ModelResponse:
          """Build a [`ModelResponse`][pydantic_ai.messages.ModelResponse] from the data received from the stream so far."""
          return ModelResponse(
-             parts=self._parts_manager.get_parts(), model_name=self._model_name, timestamp=self.timestamp()
+             parts=self._parts_manager.get_parts(), model_name=self.model_name, timestamp=self.timestamp
          )
 
-     def model_name(self) -> str:
-         """Get the model name of the response."""
-         return self._model_name
-
      def usage(self) -> Usage:
          """Get the usage of the response so far. This will not be the final usage until the stream is exhausted."""
          return self._usage
 
+     @property
+     @abstractmethod
+     def model_name(self) -> str:
+         """Get the model name of the response."""
+         raise NotImplementedError()
+
+     @property
      @abstractmethod
      def timestamp(self) -> datetime:
          """Get the timestamp of the response."""
@@ -357,7 +364,6 @@ def infer_model(model: Model | KnownModelName) -> Model:
          raise UserError(f'Unknown model: {model}')
 
 
- @cache
  def cached_async_http_client(timeout: int = 600, connect: int = 5) -> httpx.AsyncClient:
      """Cached HTTPX async client so multiple agents and calls can share the same client.
 
@@ -368,6 +374,16 @@ def cached_async_http_client(timeout: int = 600, connect: int = 5) -> httpx.Asyn
      The default timeouts match those of OpenAI,
      see <https://github.com/openai/openai-python/blob/v1.54.4/src/openai/_constants.py#L9>.
      """
+     client = _cached_async_http_client(timeout=timeout, connect=connect)
+     if client.is_closed:
+         # This happens if the context manager is used, so we need to create a new client.
+         _cached_async_http_client.cache_clear()
+         client = _cached_async_http_client(timeout=timeout, connect=connect)
+     return client
+
+
+ @cache
+ def _cached_async_http_client(timeout: int = 600, connect: int = 5) -> httpx.AsyncClient:
      return httpx.AsyncClient(
          timeout=httpx.Timeout(timeout=timeout, connect=connect),
          headers={'User-Agent': get_user_agent()},
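
Splitting the public function from the `@cache`-decorated `_cached_async_http_client` guards against a caller having used the shared client as an async context manager, which closes it on exit. A small demonstration of the behavior this enables (a sketch, not taken from the package's tests):

    import asyncio

    from pydantic_ai.models import cached_async_http_client

    async def main():
        async with cached_async_http_client():
            pass  # exiting the context manager closes the cached client

        # The wrapper detects the closed client, clears the cache entry,
        # and returns a fresh, usable client instead.
        assert not cached_async_http_client().is_closed

    asyncio.run(main())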

pydantic_ai/models/anthropic.py CHANGED
@@ -162,6 +162,16 @@ class AnthropicModel(Model):
          async with response:
              yield await self._process_streamed_response(response)
 
+     @property
+     def model_name(self) -> AnthropicModelName:
+         """The model name."""
+         return self._model_name
+
+     @property
+     def system(self) -> str | None:
+         """The system / model provider."""
+         return self._system
+
      @overload
      async def _messages_create(
          self,
@@ -236,7 +246,7 @@ class AnthropicModel(Model):
                      )
                  )
 
-         return ModelResponse(items, model_name=self._model_name)
+         return ModelResponse(items, model_name=response.model)
 
      async def _process_streamed_response(self, response: AsyncStream[RawMessageStreamEvent]) -> StreamedResponse:
          peekable_response = _utils.PeekableAsyncStream(response)
@@ -262,64 +272,56 @@ class AnthropicModel(Model):
          anthropic_messages: list[MessageParam] = []
          for m in messages:
              if isinstance(m, ModelRequest):
-                 for part in m.parts:
-                     if isinstance(part, SystemPromptPart):
-                         system_prompt += part.content
-                     elif isinstance(part, UserPromptPart):
-                         anthropic_messages.append(MessageParam(role='user', content=part.content))
-                     elif isinstance(part, ToolReturnPart):
-                         anthropic_messages.append(
-                             MessageParam(
-                                 role='user',
-                                 content=[
-                                     ToolResultBlockParam(
-                                         tool_use_id=_guard_tool_call_id(t=part, model_source='Anthropic'),
-                                         type='tool_result',
-                                         content=part.model_response_str(),
-                                         is_error=False,
-                                     )
-                                 ],
-                             )
+                 user_content_params: list[ToolResultBlockParam | TextBlockParam] = []
+                 for request_part in m.parts:
+                     if isinstance(request_part, SystemPromptPart):
+                         system_prompt += request_part.content
+                     elif isinstance(request_part, UserPromptPart):
+                         text_block_param = TextBlockParam(type='text', text=request_part.content)
+                         user_content_params.append(text_block_param)
+                     elif isinstance(request_part, ToolReturnPart):
+                         tool_result_block_param = ToolResultBlockParam(
+                             tool_use_id=_guard_tool_call_id(t=request_part, model_source='Anthropic'),
+                             type='tool_result',
+                             content=request_part.model_response_str(),
+                             is_error=False,
                          )
-                     elif isinstance(part, RetryPromptPart):
-                         if part.tool_name is None:
-                             anthropic_messages.append(MessageParam(role='user', content=part.model_response()))
+                         user_content_params.append(tool_result_block_param)
+                     elif isinstance(request_part, RetryPromptPart):
+                         if request_part.tool_name is None:
+                             retry_param = TextBlockParam(type='text', text=request_part.model_response())
                          else:
-                             anthropic_messages.append(
-                                 MessageParam(
-                                     role='user',
-                                     content=[
-                                         ToolResultBlockParam(
-                                             tool_use_id=_guard_tool_call_id(t=part, model_source='Anthropic'),
-                                             type='tool_result',
-                                             content=part.model_response(),
-                                             is_error=True,
-                                         ),
-                                     ],
-                                 )
+                             retry_param = ToolResultBlockParam(
+                                 tool_use_id=_guard_tool_call_id(t=request_part, model_source='Anthropic'),
+                                 type='tool_result',
+                                 content=request_part.model_response(),
+                                 is_error=True,
                              )
+                         user_content_params.append(retry_param)
+                 anthropic_messages.append(
+                     MessageParam(
+                         role='user',
+                         content=user_content_params,
+                     )
+                 )
              elif isinstance(m, ModelResponse):
-                 content: list[TextBlockParam | ToolUseBlockParam] = []
-                 for item in m.parts:
-                     if isinstance(item, TextPart):
-                         content.append(TextBlockParam(text=item.content, type='text'))
+                 assistant_content_params: list[TextBlockParam | ToolUseBlockParam] = []
+                 for response_part in m.parts:
+                     if isinstance(response_part, TextPart):
+                         assistant_content_params.append(TextBlockParam(text=response_part.content, type='text'))
                      else:
-                         assert isinstance(item, ToolCallPart)
-                         content.append(self._map_tool_call(item))
-                 anthropic_messages.append(MessageParam(role='assistant', content=content))
+                         tool_use_block_param = ToolUseBlockParam(
+                             id=_guard_tool_call_id(t=response_part, model_source='Anthropic'),
+                             type='tool_use',
+                             name=response_part.tool_name,
+                             input=response_part.args_as_dict(),
+                         )
+                         assistant_content_params.append(tool_use_block_param)
+                 anthropic_messages.append(MessageParam(role='assistant', content=assistant_content_params))
              else:
                  assert_never(m)
          return system_prompt, anthropic_messages
 
-     @staticmethod
-     def _map_tool_call(t: ToolCallPart) -> ToolUseBlockParam:
-         return ToolUseBlockParam(
-             id=_guard_tool_call_id(t=t, model_source='Anthropic'),
-             type='tool_use',
-             name=t.tool_name,
-             input=t.args_as_dict(),
-         )
-
      @staticmethod
      def _map_tool_definition(f: ToolDefinition) -> ToolParam:
          return {
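
The refactored message mapping now collects every user-side part of a `ModelRequest` into a single `MessageParam` instead of emitting one message per part. Roughly, a request carrying a tool return followed by a user prompt maps to one message of this shape (illustrative values, not taken from the package):

    {
        'role': 'user',
        'content': [
            {'type': 'tool_result', 'tool_use_id': 'toolu_123', 'content': '21', 'is_error': False},
            {'type': 'text', 'text': 'Now double it.'},
        ],
    }
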
@@ -362,6 +364,7 @@ def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) -> usage.Usage
  class AnthropicStreamedResponse(StreamedResponse):
      """Implementation of `StreamedResponse` for Anthropic models."""
 
+     _model_name: AnthropicModelName
      _response: AsyncIterable[RawMessageStreamEvent]
      _timestamp: datetime
 
@@ -414,5 +417,12 @@ class AnthropicStreamedResponse(StreamedResponse):
          elif isinstance(event, (RawContentBlockStopEvent, RawMessageStopEvent)):
              current_block = None
 
+     @property
+     def model_name(self) -> AnthropicModelName:
+         """Get the model name of the response."""
+         return self._model_name
+
+     @property
      def timestamp(self) -> datetime:
+         """Get the timestamp of the response."""
          return self._timestamp

pydantic_ai/models/cohere.py CHANGED
@@ -124,7 +124,7 @@ class CohereModel(Model):
              assert api_key is None, 'Cannot provide both `cohere_client` and `api_key`'
              self.client = cohere_client
          else:
-             self.client = AsyncClientV2(api_key=api_key, httpx_client=http_client) # type: ignore
+             self.client = AsyncClientV2(api_key=api_key, httpx_client=http_client)
 
      async def request(
          self,
@@ -136,6 +136,16 @@ class CohereModel(Model):
          response = await self._chat(messages, cast(CohereModelSettings, model_settings or {}), model_request_parameters)
          return self._process_response(response), _map_usage(response)
 
+     @property
+     def model_name(self) -> CohereModelName:
+         """The model name."""
+         return self._model_name
+
+     @property
+     def system(self) -> str | None:
+         """The system / model provider."""
+         return self._system
+
      async def _chat(
          self,
          messages: list[ModelMessage],

pydantic_ai/models/function.py CHANGED
@@ -109,9 +109,9 @@ class FunctionModel(Model):
              model_settings,
          )
 
-         assert (
-             self.stream_function is not None
-         ), 'FunctionModel must receive a `stream_function` to support streamed requests'
+         assert self.stream_function is not None, (
+             'FunctionModel must receive a `stream_function` to support streamed requests'
+         )
 
          response_stream = PeekableAsyncStream(self.stream_function(messages, agent_info))
 
@@ -121,6 +121,16 @@ class FunctionModel(Model):
 
          yield FunctionStreamedResponse(_model_name=f'function:{self.stream_function.__name__}', _iter=response_stream)
 
+     @property
+     def model_name(self) -> str:
+         """The model name."""
+         return self._model_name
+
+     @property
+     def system(self) -> str | None:
+         """The system / model provider."""
+         return self._system
+
 
  @dataclass(frozen=True)
  class AgentInfo:
@@ -178,6 +188,7 @@ E.g. you need to yield all text or all `DeltaToolCalls`, not mix them.
  class FunctionStreamedResponse(StreamedResponse):
      """Implementation of `StreamedResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel]."""
 
+     _model_name: str
      _iter: AsyncIterator[str | DeltaToolCalls]
      _timestamp: datetime = field(default_factory=_utils.now_utc)
 
@@ -205,7 +216,14 @@ class FunctionStreamedResponse(StreamedResponse):
              if maybe_event is not None:
                  yield maybe_event
 
+     @property
+     def model_name(self) -> str:
+         """Get the model name of the response."""
+         return self._model_name
+
+     @property
      def timestamp(self) -> datetime:
+         """Get the timestamp of the response."""
          return self._timestamp
 
 

pydantic_ai/models/gemini.py CHANGED
@@ -47,6 +47,8 @@ LatestGeminiModelNames = Literal[
      'gemini-2.0-flash-exp',
      'gemini-2.0-flash-thinking-exp-01-21',
      'gemini-exp-1206',
+     'gemini-2.0-flash',
+     'gemini-2.0-flash-lite-preview-02-05',
  ]
  """Latest Gemini models."""
 
@@ -147,6 +149,16 @@ class GeminiModel(Model):
          ) as http_response:
              yield await self._process_streamed_response(http_response)
 
+     @property
+     def model_name(self) -> GeminiModelName:
+         """The model name."""
+         return self._model_name
+
+     @property
+     def system(self) -> str | None:
+         """The system / model provider."""
+         return self._system
+
      def _get_tools(self, model_request_parameters: ModelRequestParameters) -> _GeminiTools | None:
          tools = [_function_from_abstract_tool(t) for t in model_request_parameters.function_tools]
          if model_request_parameters.result_tools:
@@ -231,7 +243,7 @@
          else:
              raise UnexpectedModelBehavior('Content field missing from Gemini response', str(response))
          parts = response['candidates'][0]['content']['parts']
-         return _process_response_from_parts(parts, model_name=self._model_name)
+         return _process_response_from_parts(parts, model_name=response.get('model_version', self._model_name))
 
      async def _process_streamed_response(self, http_response: HTTPResponse) -> StreamedResponse:
          """Process a streamed response, and prepare a streaming response to return."""
@@ -242,7 +254,7 @@
          async for chunk in aiter_bytes:
              content.extend(chunk)
              responses = _gemini_streamed_response_ta.validate_json(
-                 content,
+                 _ensure_decodeable(content),
                  experimental_allow_partial='trailing-strings',
              )
              if responses:
@@ -313,6 +325,7 @@ class ApiKeyAuth:
  class GeminiStreamedResponse(StreamedResponse):
      """Implementation of `StreamedResponse` for the Gemini model."""
 
+     _model_name: GeminiModelName
      _content: bytearray
      _stream: AsyncIterator[bytes]
      _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
@@ -357,7 +370,7 @@ class GeminiStreamedResponse(StreamedResponse):
              self._content.extend(chunk)
 
              gemini_responses = _gemini_streamed_response_ta.validate_json(
-                 self._content,
+                 _ensure_decodeable(self._content),
                  experimental_allow_partial='trailing-strings',
              )
 
@@ -376,7 +389,14 @@
              self._usage += _metadata_as_usage(r)
              yield r
 
+     @property
+     def model_name(self) -> GeminiModelName:
+         """Get the model name of the response."""
+         return self._model_name
+
+     @property
      def timestamp(self) -> datetime:
+         """Get the timestamp of the response."""
          return self._timestamp
 
 
@@ -608,6 +628,7 @@ class _GeminiResponse(TypedDict):
      # usageMetadata appears to be required by both APIs but is omitted when streaming responses until the last response
      usage_metadata: NotRequired[Annotated[_GeminiUsageMetaData, pydantic.Field(alias='usageMetadata')]]
      prompt_feedback: NotRequired[Annotated[_GeminiPromptFeedback, pydantic.Field(alias='promptFeedback')]]
+     model_version: NotRequired[Annotated[str, pydantic.Field(alias='modelVersion')]]
 
 
  class _GeminiCandidates(TypedDict):
@@ -753,3 +774,19 @@ class _GeminiJsonSchema:
 
          if items_schema := schema.get('items'): # pragma: no branch
              self._simplify(items_schema, refs_stack)
+
+
+ def _ensure_decodeable(content: bytearray) -> bytearray:
+     """Trim any invalid unicode point bytes off the end of a bytearray.
+
+     This is necessary before attempting to parse streaming JSON bytes.
+
+     This is a temporary workaround until https://github.com/pydantic/pydantic-core/issues/1633 is resolved
+     """
+     while True:
+         try:
+             content.decode()
+         except UnicodeDecodeError:
+             content = content[:-1]  # this will definitely succeed before we run out of bytes
+         else:
+             return content
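
Given the `_ensure_decodeable` helper above, a quick illustration of the trimming behavior on a stream cut off mid-character:

    # '€' encodes to three bytes (b'\xe2\x82\xac'); slicing the payload here
    # leaves a dangling lead byte of that sequence, which the helper trims away.
    truncated = bytearray('{"text": "€"}'.encode()[:-4])
    assert _ensure_decodeable(truncated).decode() == '{"text": "'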

pydantic_ai/models/groq.py CHANGED
@@ -146,6 +146,16 @@ class GroqModel(Model):
          async with response:
              yield await self._process_streamed_response(response)
 
+     @property
+     def model_name(self) -> GroqModelName:
+         """The model name."""
+         return self._model_name
+
+     @property
+     def system(self) -> str | None:
+         """The system / model provider."""
+         return self._system
+
      @overload
      async def _completions_create(
          self,
@@ -212,7 +222,7 @@ class GroqModel(Model):
          if choice.message.tool_calls is not None:
              for c in choice.message.tool_calls:
                  items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
-         return ModelResponse(items, model_name=self._model_name, timestamp=timestamp)
+         return ModelResponse(items, model_name=response.model, timestamp=timestamp)
 
      async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> GroqStreamedResponse:
          """Process a streamed response, and prepare a streaming response to return."""
@@ -305,6 +315,7 @@ class GroqModel(Model):
  class GroqStreamedResponse(StreamedResponse):
      """Implementation of `StreamedResponse` for Groq models."""
 
+     _model_name: GroqModelName
      _response: AsyncIterable[ChatCompletionChunk]
      _timestamp: datetime
 
@@ -333,7 +344,14 @@ class GroqStreamedResponse(StreamedResponse):
              if maybe_event is not None:
                  yield maybe_event
 
+     @property
+     def model_name(self) -> GroqModelName:
+         """Get the model name of the response."""
+         return self._model_name
+
+     @property
      def timestamp(self) -> datetime:
+         """Get the timestamp of the response."""
          return self._timestamp
 
 