pydantic-ai-slim 0.0.23__py3-none-any.whl → 0.0.25__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in the public registry. It is provided for informational purposes only.

Potentially problematic release.

@@ -0,0 +1,225 @@
+from __future__ import annotations
+
+from collections.abc import AsyncIterator, Iterator
+from contextlib import asynccontextmanager, contextmanager
+from dataclasses import dataclass, field
+from functools import partial
+from typing import Any, Callable, Literal
+
+import logfire_api
+from opentelemetry._events import Event, EventLogger, EventLoggerProvider, get_event_logger_provider
+from opentelemetry.trace import Tracer, TracerProvider, get_tracer_provider
+
+from ..messages import (
+    ModelMessage,
+    ModelRequest,
+    ModelRequestPart,
+    ModelResponse,
+    RetryPromptPart,
+    SystemPromptPart,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
+    UserPromptPart,
+)
+from ..settings import ModelSettings
+from ..usage import Usage
+from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse
+from .wrapper import WrapperModel
+
+MODEL_SETTING_ATTRIBUTES: tuple[
+    Literal[
+        'max_tokens',
+        'top_p',
+        'seed',
+        'temperature',
+        'presence_penalty',
+        'frequency_penalty',
+    ],
+    ...,
+] = (
+    'max_tokens',
+    'top_p',
+    'seed',
+    'temperature',
+    'presence_penalty',
+    'frequency_penalty',
+)
+
+NOT_GIVEN = object()
+
+
+@dataclass
+class InstrumentedModel(WrapperModel):
+    """Model which is instrumented with logfire."""
+
+    tracer: Tracer = field(repr=False)
+    event_logger: EventLogger = field(repr=False)
+
+    def __init__(
+        self,
+        wrapped: Model | KnownModelName,
+        tracer_provider: TracerProvider | None = None,
+        event_logger_provider: EventLoggerProvider | None = None,
+    ):
+        super().__init__(wrapped)
+        tracer_provider = tracer_provider or get_tracer_provider()
+        event_logger_provider = event_logger_provider or get_event_logger_provider()
+        self.tracer = tracer_provider.get_tracer('pydantic-ai')
+        self.event_logger = event_logger_provider.get_event_logger('pydantic-ai')
+
+    @classmethod
+    def from_logfire(
+        cls,
+        wrapped: Model | KnownModelName,
+        logfire_instance: logfire_api.Logfire = logfire_api.DEFAULT_LOGFIRE_INSTANCE,
+    ) -> InstrumentedModel:
+        if hasattr(logfire_instance.config, 'get_event_logger_provider'):
+            event_provider = logfire_instance.config.get_event_logger_provider()
+        else:
+            event_provider = None
+        tracer_provider = logfire_instance.config.get_tracer_provider()
+        return cls(wrapped, tracer_provider, event_provider)
+
+    async def request(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> tuple[ModelResponse, Usage]:
+        with self._instrument(messages, model_settings) as finish:
+            response, usage = await super().request(messages, model_settings, model_request_parameters)
+            finish(response, usage)
+            return response, usage
+
+    @asynccontextmanager
+    async def request_stream(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> AsyncIterator[StreamedResponse]:
+        with self._instrument(messages, model_settings) as finish:
+            response_stream: StreamedResponse | None = None
+            try:
+                async with super().request_stream(
+                    messages, model_settings, model_request_parameters
+                ) as response_stream:
+                    yield response_stream
+            finally:
+                if response_stream:
+                    finish(response_stream.get(), response_stream.usage())
+
+    @contextmanager
+    def _instrument(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+    ) -> Iterator[Callable[[ModelResponse, Usage], None]]:
+        operation = 'chat'
+        model_name = self.model_name
+        span_name = f'{operation} {model_name}'
+        system = getattr(self.wrapped, 'system', '') or self.wrapped.__class__.__name__.removesuffix('Model').lower()
+        system = {'google-gla': 'gemini', 'google-vertex': 'vertex_ai', 'mistral': 'mistral_ai'}.get(system, system)
+        # TODO Missing attributes:
+        #  - server.address: requires a Model.base_url abstract method or similar
+        #  - server.port: to parse from the base_url
+        #  - error.type: unclear if we should do something here or just always rely on span exceptions
+        #  - gen_ai.request.stop_sequences/top_k: model_settings doesn't include these
+        attributes: dict[str, Any] = {
+            'gen_ai.operation.name': operation,
+            'gen_ai.system': system,
+            'gen_ai.request.model': model_name,
+        }
+
+        if model_settings:
+            for key in MODEL_SETTING_ATTRIBUTES:
+                if (value := model_settings.get(key, NOT_GIVEN)) is not NOT_GIVEN:
+                    attributes[f'gen_ai.request.{key}'] = value
+
+        emit_event = partial(self._emit_event, system)
+
+        with self.tracer.start_as_current_span(span_name, attributes=attributes) as span:
+            if span.is_recording():
+                for message in messages:
+                    if isinstance(message, ModelRequest):
+                        for part in message.parts:
+                            event_name, body = _request_part_body(part)
+                            if event_name:
+                                emit_event(event_name, body)
+                    elif isinstance(message, ModelResponse):
+                        for body in _response_bodies(message):
+                            emit_event('gen_ai.assistant.message', body)
+
+            def finish(response: ModelResponse, usage: Usage):
+                if not span.is_recording():
+                    return
+
+                for response_body in _response_bodies(response):
+                    if response_body:
+                        emit_event(
+                            'gen_ai.choice',
+                            {
+                                # TODO finish_reason
+                                'index': 0,
+                                'message': response_body,
+                            },
+                        )
+                span.set_attributes(
+                    {
+                        k: v
+                        for k, v in {
+                            # TODO finish_reason (https://github.com/open-telemetry/semantic-conventions/issues/1277), id
+                            #  https://github.com/pydantic/pydantic-ai/issues/886
+                            'gen_ai.response.model': response.model_name or model_name,
+                            'gen_ai.usage.input_tokens': usage.request_tokens,
+                            'gen_ai.usage.output_tokens': usage.response_tokens,
+                        }.items()
+                        if v is not None
+                    }
+                )
+
+            yield finish
+
+    def _emit_event(self, system: str, event_name: str, body: dict[str, Any]) -> None:
+        self.event_logger.emit(Event(event_name, body=body, attributes={'gen_ai.system': system}))
+
+
+def _request_part_body(part: ModelRequestPart) -> tuple[str, dict[str, Any]]:
+    if isinstance(part, SystemPromptPart):
+        return 'gen_ai.system.message', {'content': part.content, 'role': 'system'}
+    elif isinstance(part, UserPromptPart):
+        return 'gen_ai.user.message', {'content': part.content, 'role': 'user'}
+    elif isinstance(part, ToolReturnPart):
+        return 'gen_ai.tool.message', {'content': part.content, 'role': 'tool', 'id': part.tool_call_id}
+    elif isinstance(part, RetryPromptPart):
+        if part.tool_name is None:
+            return 'gen_ai.user.message', {'content': part.model_response(), 'role': 'user'}
+        else:
+            return 'gen_ai.tool.message', {'content': part.model_response(), 'role': 'tool', 'id': part.tool_call_id}
+    else:
+        return '', {}
+
+
+def _response_bodies(message: ModelResponse) -> list[dict[str, Any]]:
+    body: dict[str, Any] = {'role': 'assistant'}
+    result = [body]
+    for part in message.parts:
+        if isinstance(part, ToolCallPart):
+            body.setdefault('tool_calls', []).append(
+                {
+                    'id': part.tool_call_id,
+                    'type': 'function',  # TODO https://github.com/pydantic/pydantic-ai/issues/888
+                    'function': {
+                        'name': part.tool_name,
+                        'arguments': part.args,
+                    },
+                }
+            )
+        elif isinstance(part, TextPart):
+            if body.get('content'):
+                body = {'role': 'assistant'}
+                result.append(body)
+            body['content'] = part.content
+
+    return result
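
The new module above wraps any model so that each request is traced using the OpenTelemetry GenAI semantic conventions. A minimal usage sketch, not part of the diff; the import path pydantic_ai.models.instrumented and the model name are assumptions:

import logfire  # any configured OpenTelemetry SDK would work equally well
from pydantic_ai.models.instrumented import InstrumentedModel  # assumed path

logfire.configure()  # installs tracer and event-logger providers
model = InstrumentedModel.from_logfire('openai:gpt-4o')
# Each model.request(...) now records a 'chat gpt-4o' span carrying
# gen_ai.* attributes, plus one event per request/response message part.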
@@ -134,9 +134,6 @@ class MistralModel(Model):
         api_key = os.getenv('MISTRAL_API_KEY') if api_key is None else api_key
         self.client = Mistral(api_key=api_key, async_client=http_client or cached_async_http_client())
 
-    def name(self) -> str:
-        return f'mistral:{self._model_name}'
-
     async def request(
         self,
         messages: list[ModelMessage],
@@ -165,6 +162,16 @@ class MistralModel(Model):
         async with response:
             yield await self._process_streamed_response(model_request_parameters.result_tools, response)
 
+    @property
+    def model_name(self) -> MistralModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str | None:
+        """The system / model provider."""
+        return self._system
+
     async def _completions_create(
         self,
         messages: list[ModelMessage],
@@ -296,7 +303,7 @@ class MistralModel(Model):
                 tool = self._map_mistral_to_pydantic_tool_call(tool_call=tool_call)
                 parts.append(tool)
 
-        return ModelResponse(parts, model_name=self._model_name, timestamp=timestamp)
+        return ModelResponse(parts, model_name=response.model, timestamp=timestamp)
 
     async def _process_streamed_response(
         self,
@@ -461,6 +468,7 @@ MistralToolCallId = Union[str, None]
 class MistralStreamedResponse(StreamedResponse):
     """Implementation of `StreamedResponse` for Mistral models."""
 
+    _model_name: MistralModelName
     _response: AsyncIterable[MistralCompletionEvent]
     _timestamp: datetime
     _result_tools: dict[str, ToolDefinition]
@@ -502,7 +510,14 @@ class MistralStreamedResponse(StreamedResponse):
                     vendor_part_id=index, tool_name=dtc.function.name, args=dtc.function.arguments, tool_call_id=dtc.id
                 )
 
+    @property
+    def model_name(self) -> MistralModelName:
+        """Get the model name of the response."""
+        return self._model_name
+
+    @property
     def timestamp(self) -> datetime:
+        """Get the timestamp of the response."""
         return self._timestamp
 
     @staticmethod
@@ -119,10 +119,11 @@ class OpenAIModel(Model):
         """
         self._model_name = model_name
         # This is a workaround for the OpenAI client requiring an API key, whilst locally served,
-        # openai compatible models do not always need an API key.
+        # openai compatible models do not always need an API key, but a placeholder (non-empty) key is required.
         if api_key is None and 'OPENAI_API_KEY' not in os.environ and base_url is not None and openai_client is None:
-            api_key = ''
-        elif openai_client is not None:
+            api_key = 'api-key-not-set'
+
+        if openai_client is not None:
             assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
             assert base_url is None, 'Cannot provide both `openai_client` and `base_url`'
             assert api_key is None, 'Cannot provide both `openai_client` and `api_key`'
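
A hedged sketch of the fallback in the hunk above (constructor arguments are taken from this diff; the endpoint URL and model name are illustrative): when a base_url is given but no API key is available, the client now receives a non-empty placeholder rather than the empty string the old code used, which the OpenAI client rejects.

from pydantic_ai.models.openai import OpenAIModel

# Locally served, OpenAI-compatible endpoint that needs no real key:
model = OpenAIModel('local-llama', base_url='http://localhost:8080/v1')
# Internally this behaves as if api_key='api-key-not-set' had been passed.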
@@ -134,9 +135,6 @@ class OpenAIModel(Model):
         self.system_prompt_role = system_prompt_role
         self._system = system
 
-    def name(self) -> str:
-        return f'openai:{self._model_name}'
-
     async def request(
         self,
         messages: list[ModelMessage],
@@ -163,6 +161,16 @@ class OpenAIModel(Model):
         async with response:
             yield await self._process_streamed_response(response)
 
+    @property
+    def model_name(self) -> OpenAIModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str | None:
+        """The system / model provider."""
+        return self._system
+
     @overload
     async def _completions_create(
         self,
@@ -232,7 +240,7 @@ class OpenAIModel(Model):
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
                 items.append(ToolCallPart(c.function.name, c.function.arguments, c.id))
-        return ModelResponse(items, model_name=self._model_name, timestamp=timestamp)
+        return ModelResponse(items, model_name=response.model, timestamp=timestamp)
 
     async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> OpenAIStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
@@ -331,6 +339,7 @@ class OpenAIModel(Model):
 class OpenAIStreamedResponse(StreamedResponse):
     """Implementation of `StreamedResponse` for OpenAI models."""
 
+    _model_name: OpenAIModelName
     _response: AsyncIterable[ChatCompletionChunk]
     _timestamp: datetime
 
@@ -358,7 +367,14 @@ class OpenAIStreamedResponse(StreamedResponse):
                 if maybe_event is not None:
                     yield maybe_event
 
+    @property
+    def model_name(self) -> OpenAIModelName:
+        """Get the model name of the response."""
+        return self._model_name
+
+    @property
     def timestamp(self) -> datetime:
+        """Get the timestamp of the response."""
         return self._timestamp
 
 
@@ -107,6 +107,16 @@ class TestModel(Model):
             _model_name=self._model_name, _structured_response=model_response, _messages=messages
         )
 
+    @property
+    def model_name(self) -> str:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str | None:
+        """The system / model provider."""
+        return self._system
+
     def gen_tool_args(self, tool_def: ToolDefinition) -> Any:
         return _JsonSchemaTestData(tool_def.parameters_json_schema, self.seed).generate()
 
@@ -120,15 +130,15 @@ class TestModel(Model):
 
     def _get_result(self, model_request_parameters: ModelRequestParameters) -> _TextResult | _FunctionToolResult:
         if self.custom_result_text is not None:
-            assert (
-                model_request_parameters.allow_text_result
-            ), 'Plain response not allowed, but `custom_result_text` is set.'
+            assert model_request_parameters.allow_text_result, (
+                'Plain response not allowed, but `custom_result_text` is set.'
+            )
             assert self.custom_result_args is None, 'Cannot set both `custom_result_text` and `custom_result_args`.'
             return _TextResult(self.custom_result_text)
         elif self.custom_result_args is not None:
-            assert (
-                model_request_parameters.result_tools is not None
-            ), 'No result tools provided, but `custom_result_args` is set.'
+            assert model_request_parameters.result_tools is not None, (
+                'No result tools provided, but `custom_result_args` is set.'
+            )
             result_tool = model_request_parameters.result_tools[0]
 
             if k := result_tool.outer_typed_dict_key:
@@ -221,9 +231,9 @@ class TestModel(Model):
 class TestStreamedResponse(StreamedResponse):
     """A structured response that streams test data."""
 
+    _model_name: str
     _structured_response: ModelResponse
     _messages: InitVar[Iterable[ModelMessage]]
-
     _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
 
     def __post_init__(self, _messages: Iterable[ModelMessage]):
@@ -249,7 +259,14 @@ class TestStreamedResponse(StreamedResponse):
                     vendor_part_id=i, tool_name=part.tool_name, args=part.args, tool_call_id=part.tool_call_id
                 )
 
+    @property
+    def model_name(self) -> str:
+        """Get the model name of the response."""
+        return self._model_name
+
+    @property
     def timestamp(self) -> datetime:
+        """Get the timestamp of the response."""
         return self._timestamp
 
 
@@ -161,6 +161,16 @@ class VertexAIModel(GeminiModel):
         async with super().request_stream(messages, model_settings, model_request_parameters) as value:
             yield value
 
+    @property
+    def model_name(self) -> GeminiModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str | None:
+        """The system / model provider."""
+        return self._system
+
 
 # pyright: reportUnknownMemberType=false
 def _creds_from_file(service_account_file: str | Path) -> ServiceAccountCredentials:
@@ -0,0 +1,45 @@
+from __future__ import annotations
+
+from collections.abc import AsyncIterator
+from contextlib import asynccontextmanager
+from dataclasses import dataclass
+from typing import Any
+
+from ..messages import ModelMessage, ModelResponse
+from ..settings import ModelSettings
+from ..usage import Usage
+from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model
+
+
+@dataclass(init=False)
+class WrapperModel(Model):
+    """Model which wraps another model."""
+
+    wrapped: Model
+
+    def __init__(self, wrapped: Model | KnownModelName):
+        self.wrapped = infer_model(wrapped)
+
+    async def request(self, *args: Any, **kwargs: Any) -> tuple[ModelResponse, Usage]:
+        return await self.wrapped.request(*args, **kwargs)
+
+    @asynccontextmanager
+    async def request_stream(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> AsyncIterator[StreamedResponse]:
+        async with self.wrapped.request_stream(messages, model_settings, model_request_parameters) as response_stream:
+            yield response_stream
+
+    @property
+    def model_name(self) -> str:
+        return self.wrapped.model_name
+
+    @property
+    def system(self) -> str | None:
+        return self.wrapped.system
+
+    def __getattr__(self, item: str):
+        return getattr(self.wrapped, item)
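
WrapperModel delegates every call and attribute lookup to the wrapped model, so subclasses override only what they need; InstrumentedModel in this same release is built on exactly this hook. An illustrative subclass, not from the diff; the import paths are assumptions:

from __future__ import annotations

from typing import Any

from pydantic_ai.models import KnownModelName, Model  # assumed path
from pydantic_ai.models.wrapper import WrapperModel  # assumed path


class CountingModel(WrapperModel):
    """Delegates to the wrapped model while counting requests."""

    def __init__(self, wrapped: Model | KnownModelName):
        super().__init__(wrapped)
        self.request_count = 0

    async def request(self, *args: Any, **kwargs: Any):
        self.request_count += 1
        return await super().request(*args, **kwargs)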