pydantic-ai-slim 0.0.24__py3-none-any.whl → 0.0.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


pydantic_ai/messages.py CHANGED
@@ -1,5 +1,6 @@
  from __future__ import annotations as _annotations

+ import uuid
  from dataclasses import dataclass, field, replace
  from datetime import datetime
  from typing import Annotated, Any, Literal, Union, cast, overload
@@ -445,3 +446,33 @@ class PartDeltaEvent:

  ModelResponseStreamEvent = Annotated[Union[PartStartEvent, PartDeltaEvent], pydantic.Discriminator('event_kind')]
  """An event in the model response stream, either starting a new part or applying a delta to an existing one."""
+
+
+ @dataclass
+ class FunctionToolCallEvent:
+     """An event indicating the start to a call to a function tool."""
+
+     part: ToolCallPart
+     """The (function) tool call to make."""
+     call_id: str = field(init=False)
+     """An ID used for matching details about the call to its result. If present, defaults to the part's tool_call_id."""
+     event_kind: Literal['function_tool_call'] = 'function_tool_call'
+     """Event type identifier, used as a discriminator."""
+
+     def __post_init__(self):
+         self.call_id = self.part.tool_call_id or str(uuid.uuid4())
+
+
+ @dataclass
+ class FunctionToolResultEvent:
+     """An event indicating the result of a function tool call."""
+
+     result: ToolReturnPart | RetryPromptPart
+     """The result of the call to the function tool."""
+     call_id: str
+     """An ID used to match the result to its original call."""
+     event_kind: Literal['function_tool_result'] = 'function_tool_result'
+     """Event type identifier, used as a discriminator."""
+
+
+ HandleResponseEvent = Annotated[Union[FunctionToolCallEvent, FunctionToolResultEvent], pydantic.Discriminator('kind')]
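
Note: the two new events are correlated through `call_id` — `FunctionToolCallEvent` generates one (falling back to a `uuid4` when the part has no `tool_call_id`), and the matching `FunctionToolResultEvent` carries the same value. A minimal sketch of pairing calls with results while consuming a stream of `HandleResponseEvent` values (the `events` iterable here is a hypothetical stand-in, not an API from this diff):

    pending: dict[str, FunctionToolCallEvent] = {}
    for event in events:
        if event.event_kind == 'function_tool_call':
            pending[event.call_id] = event
        elif event.event_kind == 'function_tool_result':
            call_event = pending.pop(event.call_id)
            print(f'{call_event.part.tool_name!r} returned {event.result!r}')
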
@@ -234,6 +234,8 @@ class StreamedResponse(ABC):

          This method should be implemented by subclasses to translate the vendor-specific stream of events into
          pydantic_ai-format events.
+
+         It should use the `_parts_manager` to handle deltas, and should update the `_usage` attributes as it goes.
          """
          raise NotImplementedError()
          # noinspection PyUnreachableCode
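
Note: a rough sketch of what this docstring asks of a `StreamedResponse` subclass. The vendor chunk shape, the `_map_usage` helper, and the exact `_parts_manager.handle_text_delta` signature are assumptions for illustration, not taken from this diff:

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        async for chunk in self._vendor_stream:  # hypothetical vendor-specific stream
            self._usage += _map_usage(chunk)  # hypothetical vendor-usage translation
            yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=chunk.text)
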
@@ -362,7 +364,6 @@ def infer_model(model: Model | KnownModelName) -> Model:
          raise UserError(f'Unknown model: {model}')


- @cache
  def cached_async_http_client(timeout: int = 600, connect: int = 5) -> httpx.AsyncClient:
      """Cached HTTPX async client so multiple agents and calls can share the same client.

@@ -373,6 +374,16 @@ def cached_async_http_client(timeout: int = 600, connect: int = 5) -> httpx.Asyn
      The default timeouts match those of OpenAI,
      see <https://github.com/openai/openai-python/blob/v1.54.4/src/openai/_constants.py#L9>.
      """
+     client = _cached_async_http_client(timeout=timeout, connect=connect)
+     if client.is_closed:
+         # This happens if the context manager is used, so we need to create a new client.
+         _cached_async_http_client.cache_clear()
+         client = _cached_async_http_client(timeout=timeout, connect=connect)
+     return client
+
+
+ @cache
+ def _cached_async_http_client(timeout: int = 600, connect: int = 5) -> httpx.AsyncClient:
      return httpx.AsyncClient(
          timeout=httpx.Timeout(timeout=timeout, connect=connect),
          headers={'User-Agent': get_user_agent()},
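
Note: previously the cached client was returned even after a caller had closed it (for example by using it as an `async with` context manager), and every later request would then fail. A minimal sketch of the behaviour the new wrapper guards against, assuming the function is importable from `pydantic_ai.models`:

    from pydantic_ai.models import cached_async_http_client

    async def demo() -> None:
        client = cached_async_http_client()
        async with client:  # some caller closes the shared client...
            pass
        assert client.is_closed

        fresh = cached_async_http_client()  # ...so the cache is cleared and a new, open client is returned
        assert not fresh.is_closed
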
@@ -272,64 +272,56 @@ class AnthropicModel(Model):
          anthropic_messages: list[MessageParam] = []
          for m in messages:
              if isinstance(m, ModelRequest):
-                 for part in m.parts:
-                     if isinstance(part, SystemPromptPart):
-                         system_prompt += part.content
-                     elif isinstance(part, UserPromptPart):
-                         anthropic_messages.append(MessageParam(role='user', content=part.content))
-                     elif isinstance(part, ToolReturnPart):
-                         anthropic_messages.append(
-                             MessageParam(
-                                 role='user',
-                                 content=[
-                                     ToolResultBlockParam(
-                                         tool_use_id=_guard_tool_call_id(t=part, model_source='Anthropic'),
-                                         type='tool_result',
-                                         content=part.model_response_str(),
-                                         is_error=False,
-                                     )
-                                 ],
-                             )
+                 user_content_params: list[ToolResultBlockParam | TextBlockParam] = []
+                 for request_part in m.parts:
+                     if isinstance(request_part, SystemPromptPart):
+                         system_prompt += request_part.content
+                     elif isinstance(request_part, UserPromptPart):
+                         text_block_param = TextBlockParam(type='text', text=request_part.content)
+                         user_content_params.append(text_block_param)
+                     elif isinstance(request_part, ToolReturnPart):
+                         tool_result_block_param = ToolResultBlockParam(
+                             tool_use_id=_guard_tool_call_id(t=request_part, model_source='Anthropic'),
+                             type='tool_result',
+                             content=request_part.model_response_str(),
+                             is_error=False,
                          )
-                     elif isinstance(part, RetryPromptPart):
-                         if part.tool_name is None:
-                             anthropic_messages.append(MessageParam(role='user', content=part.model_response()))
+                         user_content_params.append(tool_result_block_param)
+                     elif isinstance(request_part, RetryPromptPart):
+                         if request_part.tool_name is None:
+                             retry_param = TextBlockParam(type='text', text=request_part.model_response())
                          else:
-                             anthropic_messages.append(
-                                 MessageParam(
-                                     role='user',
-                                     content=[
-                                         ToolResultBlockParam(
-                                             tool_use_id=_guard_tool_call_id(t=part, model_source='Anthropic'),
-                                             type='tool_result',
-                                             content=part.model_response(),
-                                             is_error=True,
-                                         ),
-                                     ],
-                                 )
+                             retry_param = ToolResultBlockParam(
+                                 tool_use_id=_guard_tool_call_id(t=request_part, model_source='Anthropic'),
+                                 type='tool_result',
+                                 content=request_part.model_response(),
+                                 is_error=True,
                              )
+                         user_content_params.append(retry_param)
+                 anthropic_messages.append(
+                     MessageParam(
+                         role='user',
+                         content=user_content_params,
+                     )
+                 )
              elif isinstance(m, ModelResponse):
-                 content: list[TextBlockParam | ToolUseBlockParam] = []
-                 for item in m.parts:
-                     if isinstance(item, TextPart):
-                         content.append(TextBlockParam(text=item.content, type='text'))
+                 assistant_content_params: list[TextBlockParam | ToolUseBlockParam] = []
+                 for response_part in m.parts:
+                     if isinstance(response_part, TextPart):
+                         assistant_content_params.append(TextBlockParam(text=response_part.content, type='text'))
                      else:
-                         assert isinstance(item, ToolCallPart)
-                         content.append(self._map_tool_call(item))
-                 anthropic_messages.append(MessageParam(role='assistant', content=content))
+                         tool_use_block_param = ToolUseBlockParam(
+                             id=_guard_tool_call_id(t=response_part, model_source='Anthropic'),
+                             type='tool_use',
+                             name=response_part.tool_name,
+                             input=response_part.args_as_dict(),
+                         )
+                         assistant_content_params.append(tool_use_block_param)
+                 anthropic_messages.append(MessageParam(role='assistant', content=assistant_content_params))
              else:
                  assert_never(m)
          return system_prompt, anthropic_messages

-     @staticmethod
-     def _map_tool_call(t: ToolCallPart) -> ToolUseBlockParam:
-         return ToolUseBlockParam(
-             id=_guard_tool_call_id(t=t, model_source='Anthropic'),
-             type='tool_use',
-             name=t.tool_name,
-             input=t.args_as_dict(),
-         )
-
      @staticmethod
      def _map_tool_definition(f: ToolDefinition) -> ToolParam:
          return {
@@ -124,7 +124,7 @@ class CohereModel(Model):
              assert api_key is None, 'Cannot provide both `cohere_client` and `api_key`'
              self.client = cohere_client
          else:
-             self.client = AsyncClientV2(api_key=api_key, httpx_client=http_client)  # type: ignore
+             self.client = AsyncClientV2(api_key=api_key, httpx_client=http_client)

      async def request(
          self,
@@ -109,9 +109,9 @@ class FunctionModel(Model):
              model_settings,
          )

-         assert (
-             self.stream_function is not None
-         ), 'FunctionModel must receive a `stream_function` to support streamed requests'
+         assert self.stream_function is not None, (
+             'FunctionModel must receive a `stream_function` to support streamed requests'
+         )

          response_stream = PeekableAsyncStream(self.stream_function(messages, agent_info))

@@ -254,7 +254,7 @@ class GeminiModel(Model):
          async for chunk in aiter_bytes:
              content.extend(chunk)
              responses = _gemini_streamed_response_ta.validate_json(
-                 content,
+                 _ensure_decodeable(content),
                  experimental_allow_partial='trailing-strings',
              )
              if responses:
@@ -370,7 +370,7 @@ class GeminiStreamedResponse(StreamedResponse):
              self._content.extend(chunk)

              gemini_responses = _gemini_streamed_response_ta.validate_json(
-                 self._content,
+                 _ensure_decodeable(self._content),
                  experimental_allow_partial='trailing-strings',
              )

@@ -774,3 +774,19 @@ class _GeminiJsonSchema:

          if items_schema := schema.get('items'):  # pragma: no branch
              self._simplify(items_schema, refs_stack)
+
+
+ def _ensure_decodeable(content: bytearray) -> bytearray:
+     """Trim any invalid unicode point bytes off the end of a bytearray.
+
+     This is necessary before attempting to parse streaming JSON bytes.
+
+     This is a temporary workaround until https://github.com/pydantic/pydantic-core/issues/1633 is resolved
+     """
+     while True:
+         try:
+             content.decode()
+         except UnicodeDecodeError:
+             content = content[:-1]  # this will definitely succeed before we run out of bytes
+         else:
+             return content
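
Note: a quick illustration of `_ensure_decodeable` (the function defined just above) on a chunk boundary that splits a multi-byte UTF-8 character; the JSON fragment is made up:

    data = bytearray('{"text": "héllo…'.encode())
    truncated = data[:-1]  # cut inside the 3-byte '…' character
    assert _ensure_decodeable(truncated).decode() == '{"text": "héllo'
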
@@ -0,0 +1,225 @@
+ from __future__ import annotations
+
+ from collections.abc import AsyncIterator, Iterator
+ from contextlib import asynccontextmanager, contextmanager
+ from dataclasses import dataclass, field
+ from functools import partial
+ from typing import Any, Callable, Literal
+
+ import logfire_api
+ from opentelemetry._events import Event, EventLogger, EventLoggerProvider, get_event_logger_provider
+ from opentelemetry.trace import Tracer, TracerProvider, get_tracer_provider
+
+ from ..messages import (
+     ModelMessage,
+     ModelRequest,
+     ModelRequestPart,
+     ModelResponse,
+     RetryPromptPart,
+     SystemPromptPart,
+     TextPart,
+     ToolCallPart,
+     ToolReturnPart,
+     UserPromptPart,
+ )
+ from ..settings import ModelSettings
+ from ..usage import Usage
+ from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse
+ from .wrapper import WrapperModel
+
+ MODEL_SETTING_ATTRIBUTES: tuple[
+     Literal[
+         'max_tokens',
+         'top_p',
+         'seed',
+         'temperature',
+         'presence_penalty',
+         'frequency_penalty',
+     ],
+     ...,
+ ] = (
+     'max_tokens',
+     'top_p',
+     'seed',
+     'temperature',
+     'presence_penalty',
+     'frequency_penalty',
+ )
+
+ NOT_GIVEN = object()
+
+
+ @dataclass
+ class InstrumentedModel(WrapperModel):
+     """Model which is instrumented with logfire."""
+
+     tracer: Tracer = field(repr=False)
+     event_logger: EventLogger = field(repr=False)
+
+     def __init__(
+         self,
+         wrapped: Model | KnownModelName,
+         tracer_provider: TracerProvider | None = None,
+         event_logger_provider: EventLoggerProvider | None = None,
+     ):
+         super().__init__(wrapped)
+         tracer_provider = tracer_provider or get_tracer_provider()
+         event_logger_provider = event_logger_provider or get_event_logger_provider()
+         self.tracer = tracer_provider.get_tracer('pydantic-ai')
+         self.event_logger = event_logger_provider.get_event_logger('pydantic-ai')
+
+     @classmethod
+     def from_logfire(
+         cls,
+         wrapped: Model | KnownModelName,
+         logfire_instance: logfire_api.Logfire = logfire_api.DEFAULT_LOGFIRE_INSTANCE,
+     ) -> InstrumentedModel:
+         if hasattr(logfire_instance.config, 'get_event_logger_provider'):
+             event_provider = logfire_instance.config.get_event_logger_provider()
+         else:
+             event_provider = None
+         tracer_provider = logfire_instance.config.get_tracer_provider()
+         return cls(wrapped, tracer_provider, event_provider)
+
+     async def request(
+         self,
+         messages: list[ModelMessage],
+         model_settings: ModelSettings | None,
+         model_request_parameters: ModelRequestParameters,
+     ) -> tuple[ModelResponse, Usage]:
+         with self._instrument(messages, model_settings) as finish:
+             response, usage = await super().request(messages, model_settings, model_request_parameters)
+             finish(response, usage)
+             return response, usage
+
+     @asynccontextmanager
+     async def request_stream(
+         self,
+         messages: list[ModelMessage],
+         model_settings: ModelSettings | None,
+         model_request_parameters: ModelRequestParameters,
+     ) -> AsyncIterator[StreamedResponse]:
+         with self._instrument(messages, model_settings) as finish:
+             response_stream: StreamedResponse | None = None
+             try:
+                 async with super().request_stream(
+                     messages, model_settings, model_request_parameters
+                 ) as response_stream:
+                     yield response_stream
+             finally:
+                 if response_stream:
+                     finish(response_stream.get(), response_stream.usage())
+
+     @contextmanager
+     def _instrument(
+         self,
+         messages: list[ModelMessage],
+         model_settings: ModelSettings | None,
+     ) -> Iterator[Callable[[ModelResponse, Usage], None]]:
+         operation = 'chat'
+         model_name = self.model_name
+         span_name = f'{operation} {model_name}'
+         system = getattr(self.wrapped, 'system', '') or self.wrapped.__class__.__name__.removesuffix('Model').lower()
+         system = {'google-gla': 'gemini', 'google-vertex': 'vertex_ai', 'mistral': 'mistral_ai'}.get(system, system)
+         # TODO Missing attributes:
+         # - server.address: requires a Model.base_url abstract method or similar
+         # - server.port: to parse from the base_url
+         # - error.type: unclear if we should do something here or just always rely on span exceptions
+         # - gen_ai.request.stop_sequences/top_k: model_settings doesn't include these
+         attributes: dict[str, Any] = {
+             'gen_ai.operation.name': operation,
+             'gen_ai.system': system,
+             'gen_ai.request.model': model_name,
+         }
+
+         if model_settings:
+             for key in MODEL_SETTING_ATTRIBUTES:
+                 if (value := model_settings.get(key, NOT_GIVEN)) is not NOT_GIVEN:
+                     attributes[f'gen_ai.request.{key}'] = value
+
+         emit_event = partial(self._emit_event, system)
+
+         with self.tracer.start_as_current_span(span_name, attributes=attributes) as span:
+             if span.is_recording():
+                 for message in messages:
+                     if isinstance(message, ModelRequest):
+                         for part in message.parts:
+                             event_name, body = _request_part_body(part)
+                             if event_name:
+                                 emit_event(event_name, body)
+                     elif isinstance(message, ModelResponse):
+                         for body in _response_bodies(message):
+                             emit_event('gen_ai.assistant.message', body)
+
+             def finish(response: ModelResponse, usage: Usage):
+                 if not span.is_recording():
+                     return
+
+                 for response_body in _response_bodies(response):
+                     if response_body:
+                         emit_event(
+                             'gen_ai.choice',
+                             {
+                                 # TODO finish_reason
+                                 'index': 0,
+                                 'message': response_body,
+                             },
+                         )
+                 span.set_attributes(
+                     {
+                         k: v
+                         for k, v in {
+                             # TODO finish_reason (https://github.com/open-telemetry/semantic-conventions/issues/1277), id
+                             # https://github.com/pydantic/pydantic-ai/issues/886
+                             'gen_ai.response.model': response.model_name or model_name,
+                             'gen_ai.usage.input_tokens': usage.request_tokens,
+                             'gen_ai.usage.output_tokens': usage.response_tokens,
+                         }.items()
+                         if v is not None
+                     }
+                 )
+
+             yield finish
+
+     def _emit_event(self, system: str, event_name: str, body: dict[str, Any]) -> None:
+         self.event_logger.emit(Event(event_name, body=body, attributes={'gen_ai.system': system}))
+
+
+ def _request_part_body(part: ModelRequestPart) -> tuple[str, dict[str, Any]]:
+     if isinstance(part, SystemPromptPart):
+         return 'gen_ai.system.message', {'content': part.content, 'role': 'system'}
+     elif isinstance(part, UserPromptPart):
+         return 'gen_ai.user.message', {'content': part.content, 'role': 'user'}
+     elif isinstance(part, ToolReturnPart):
+         return 'gen_ai.tool.message', {'content': part.content, 'role': 'tool', 'id': part.tool_call_id}
+     elif isinstance(part, RetryPromptPart):
+         if part.tool_name is None:
+             return 'gen_ai.user.message', {'content': part.model_response(), 'role': 'user'}
+         else:
+             return 'gen_ai.tool.message', {'content': part.model_response(), 'role': 'tool', 'id': part.tool_call_id}
+     else:
+         return '', {}
+
+
+ def _response_bodies(message: ModelResponse) -> list[dict[str, Any]]:
+     body: dict[str, Any] = {'role': 'assistant'}
+     result = [body]
+     for part in message.parts:
+         if isinstance(part, ToolCallPart):
+             body.setdefault('tool_calls', []).append(
+                 {
+                     'id': part.tool_call_id,
+                     'type': 'function',  # TODO https://github.com/pydantic/pydantic-ai/issues/888
+                     'function': {
+                         'name': part.tool_name,
+                         'arguments': part.args,
+                     },
+                 }
+             )
+         elif isinstance(part, TextPart):
+             if body.get('content'):
+                 body = {'role': 'assistant'}
+                 result.append(body)
+             body['content'] = part.content
+
+     return result
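
Note: a hedged usage sketch for the new instrumented model. The module path `pydantic_ai.models.instrumented` and the wrapped model name are assumptions; the logfire SDK must be installed and configured for spans and events to actually be exported:

    import logfire
    from pydantic_ai.models.instrumented import InstrumentedModel

    logfire.configure()  # sets up the tracer/event-logger providers that from_logfire picks up
    model = InstrumentedModel.from_logfire('openai:gpt-4o')
    # every request/request_stream call on `model` now runs inside a 'chat <model>' span
    # carrying gen_ai.* attributes plus one event per request/response message
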
@@ -134,9 +134,6 @@ class MistralModel(Model):
              api_key = os.getenv('MISTRAL_API_KEY') if api_key is None else api_key
              self.client = Mistral(api_key=api_key, async_client=http_client or cached_async_http_client())

-     def name(self) -> str:
-         return f'mistral:{self._model_name}'
-
      async def request(
          self,
          messages: list[ModelMessage],
@@ -119,9 +119,9 @@ class OpenAIModel(Model):
          """
          self._model_name = model_name
          # This is a workaround for the OpenAI client requiring an API key, whilst locally served,
-         # openai compatible models do not always need an API key.
+         # openai compatible models do not always need an API key, but a placeholder (non-empty) key is required.
          if api_key is None and 'OPENAI_API_KEY' not in os.environ and base_url is not None and openai_client is None:
-             api_key = ''
+             api_key = 'api-key-not-set'

          if openai_client is not None:
              assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
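
Note: the placeholder only applies when a `base_url` points at a locally served OpenAI-compatible endpoint and no key is configured anywhere. A hypothetical example (the URL and model name are placeholders):

    from pydantic_ai.models.openai import OpenAIModel

    # with no OPENAI_API_KEY set, the constructor substitutes 'api-key-not-set' so the
    # underlying client accepts the (unused) credential
    model = OpenAIModel('llama3.2', base_url='http://localhost:11434/v1')
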
@@ -135,9 +135,6 @@ class OpenAIModel(Model):
          self.system_prompt_role = system_prompt_role
          self._system = system

-     def name(self) -> str:
-         return f'openai:{self._model_name}'
-
      async def request(
          self,
          messages: list[ModelMessage],
@@ -130,15 +130,15 @@ class TestModel(Model):

      def _get_result(self, model_request_parameters: ModelRequestParameters) -> _TextResult | _FunctionToolResult:
          if self.custom_result_text is not None:
-             assert (
-                 model_request_parameters.allow_text_result
-             ), 'Plain response not allowed, but `custom_result_text` is set.'
+             assert model_request_parameters.allow_text_result, (
+                 'Plain response not allowed, but `custom_result_text` is set.'
+             )
              assert self.custom_result_args is None, 'Cannot set both `custom_result_text` and `custom_result_args`.'
              return _TextResult(self.custom_result_text)
          elif self.custom_result_args is not None:
-             assert (
-                 model_request_parameters.result_tools is not None
-             ), 'No result tools provided, but `custom_result_args` is set.'
+             assert model_request_parameters.result_tools is not None, (
+                 'No result tools provided, but `custom_result_args` is set.'
+             )
              result_tool = model_request_parameters.result_tools[0]

              if k := result_tool.outer_typed_dict_key:
@@ -0,0 +1,45 @@
+ from __future__ import annotations
+
+ from collections.abc import AsyncIterator
+ from contextlib import asynccontextmanager
+ from dataclasses import dataclass
+ from typing import Any
+
+ from ..messages import ModelMessage, ModelResponse
+ from ..settings import ModelSettings
+ from ..usage import Usage
+ from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model
+
+
+ @dataclass(init=False)
+ class WrapperModel(Model):
+     """Model which wraps another model."""
+
+     wrapped: Model
+
+     def __init__(self, wrapped: Model | KnownModelName):
+         self.wrapped = infer_model(wrapped)
+
+     async def request(self, *args: Any, **kwargs: Any) -> tuple[ModelResponse, Usage]:
+         return await self.wrapped.request(*args, **kwargs)
+
+     @asynccontextmanager
+     async def request_stream(
+         self,
+         messages: list[ModelMessage],
+         model_settings: ModelSettings | None,
+         model_request_parameters: ModelRequestParameters,
+     ) -> AsyncIterator[StreamedResponse]:
+         async with self.wrapped.request_stream(messages, model_settings, model_request_parameters) as response_stream:
+             yield response_stream
+
+     @property
+     def model_name(self) -> str:
+         return self.wrapped.model_name
+
+     @property
+     def system(self) -> str | None:
+         return self.wrapped.system
+
+     def __getattr__(self, item: str):
+         return getattr(self.wrapped, item)
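
Note: since `WrapperModel` delegates `request`, `request_stream`, `model_name`, `system`, and any other attribute lookup to `wrapped`, subclasses only need to override what they change. A small illustrative subclass (the logging behaviour is made up, not part of this release):

    class LoggingModel(WrapperModel):
        async def request(self, *args: Any, **kwargs: Any) -> tuple[ModelResponse, Usage]:
            print(f'calling {self.model_name} with {len(args[0])} messages')
            return await super().request(*args, **kwargs)
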