pydantic-ai-slim 0.0.12 → 0.0.13 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

pydantic_ai/models/openai.py:

@@ -4,23 +4,29 @@ from collections.abc import AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
+from itertools import chain
 from typing import Literal, Union, overload
 
 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never
 
 from .. import UnexpectedModelBehavior, _utils, result
+from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
     ArgsJson,
-    Message,
-    ModelAnyResponse,
-    ModelStructuredResponse,
-    ModelTextResponse,
-    RetryPrompt,
-    ToolCall,
-    ToolReturn,
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    ModelResponsePart,
+    RetryPromptPart,
+    SystemPromptPart,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
+    UserPromptPart,
 )
 from ..result import Cost
+from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
     AgentModel,
@@ -40,7 +46,7 @@ try:
 except ImportError as _import_error:
     raise ImportError(
         'Please install `openai` to use the OpenAI model, '
-        "you can use the `openai` optional group — `pip install 'pydantic-ai[openai]'`"
+        "you can use the `openai` optional group — `pip install 'pydantic-ai-slim[openai]'`"
     ) from _import_error
 
 OpenAIModelName = Union[ChatModel, str]
@@ -66,6 +72,7 @@ class OpenAIModel(Model):
         self,
         model_name: OpenAIModelName,
         *,
+        base_url: str | None = None,
         api_key: str | None = None,
         openai_client: AsyncOpenAI | None = None,
         http_client: AsyncHTTPClient | None = None,
@@ -76,22 +83,25 @@ class OpenAIModel(Model):
             model_name: The name of the OpenAI model to use. List of model names available
                 [here](https://github.com/openai/openai-python/blob/v1.54.3/src/openai/types/chat_model.py#L7)
                 (Unfortunately, despite being ask to do so, OpenAI do not provide `.inv` files for their API).
+            base_url: The base url for the OpenAI requests. If not provided, the `OPENAI_BASE_URL` environment variable
+                will be used if available. Otherwise, defaults to OpenAI's base url.
             api_key: The API key to use for authentication, if not provided, the `OPENAI_API_KEY` environment variable
                 will be used if available.
             openai_client: An existing
                 [`AsyncOpenAI`](https://github.com/openai/openai-python?tab=readme-ov-file#async-usage)
-                client to use, if provided, `api_key` and `http_client` must be `None`.
+                client to use. If provided, `base_url`, `api_key`, and `http_client` must be `None`.
             http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
         """
         self.model_name: OpenAIModelName = model_name
         if openai_client is not None:
             assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
+            assert base_url is None, 'Cannot provide both `openai_client` and `base_url`'
             assert api_key is None, 'Cannot provide both `openai_client` and `api_key`'
             self.client = openai_client
         elif http_client is not None:
-            self.client = AsyncOpenAI(api_key=api_key, http_client=http_client)
+            self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client)
         else:
-            self.client = AsyncOpenAI(api_key=api_key, http_client=cached_async_http_client())
+            self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=cached_async_http_client())
 
     async def agent_model(
         self,
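For context, the new `base_url` parameter lets the client target any OpenAI-compatible endpoint. A minimal sketch of the constructor as changed above; the URL, key, and model name are illustrative placeholders, not part of the diff:

from pydantic_ai.models.openai import OpenAIModel

# Sketch only: point the OpenAI client at an OpenAI-compatible endpoint.
# 'http://localhost:8000/v1' and the key below are placeholder values.
model = OpenAIModel(
    'gpt-4o',
    base_url='http://localhost:8000/v1',
    api_key='placeholder-key',
)
# If base_url is omitted, OPENAI_BASE_URL is consulted, then OpenAI's default.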
@@ -135,28 +145,34 @@ class OpenAIAgentModel(AgentModel):
     allow_text_result: bool
     tools: list[chat.ChatCompletionToolParam]
 
-    async def request(self, messages: list[Message]) -> tuple[ModelAnyResponse, result.Cost]:
-        response = await self._completions_create(messages, False)
+    async def request(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> tuple[ModelResponse, result.Cost]:
+        response = await self._completions_create(messages, False, model_settings)
         return self._process_response(response), _map_cost(response)
 
     @asynccontextmanager
-    async def request_stream(self, messages: list[Message]) -> AsyncIterator[EitherStreamedResponse]:
-        response = await self._completions_create(messages, True)
+    async def request_stream(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> AsyncIterator[EitherStreamedResponse]:
+        response = await self._completions_create(messages, True, model_settings)
         async with response:
             yield await self._process_streamed_response(response)
 
     @overload
     async def _completions_create(
-        self, messages: list[Message], stream: Literal[True]
+        self, messages: list[ModelMessage], stream: Literal[True], model_settings: ModelSettings | None
    ) -> AsyncStream[ChatCompletionChunk]:
         pass
 
     @overload
-    async def _completions_create(self, messages: list[Message], stream: Literal[False]) -> chat.ChatCompletion:
+    async def _completions_create(
+        self, messages: list[ModelMessage], stream: Literal[False], model_settings: ModelSettings | None
+    ) -> chat.ChatCompletion:
         pass
 
     async def _completions_create(
-        self, messages: list[Message], stream: bool
+        self, messages: list[ModelMessage], stream: bool, model_settings: ModelSettings | None
     ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
         # standalone function to make it easier to override
         if not self.tools:
@@ -166,7 +182,10 @@ class OpenAIAgentModel(AgentModel):
         else:
             tool_choice = 'auto'
 
-        openai_messages = [self._map_message(m) for m in messages]
+        openai_messages = list(chain(*(self._map_message(m) for m in messages)))
+
+        model_settings = model_settings or {}
+
         return await self.client.chat.completions.create(
             model=self.model_name,
             messages=openai_messages,
@@ -176,21 +195,24 @@ class OpenAIAgentModel(AgentModel):
             tool_choice=tool_choice or NOT_GIVEN,
             stream=stream,
             stream_options={'include_usage': True} if stream else NOT_GIVEN,
+            max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
+            temperature=model_settings.get('temperature', NOT_GIVEN),
+            top_p=model_settings.get('top_p', NOT_GIVEN),
+            timeout=model_settings.get('timeout', NOT_GIVEN),
         )
 
     @staticmethod
-    def _process_response(response: chat.ChatCompletion) -> ModelAnyResponse:
+    def _process_response(response: chat.ChatCompletion) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
         timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
         choice = response.choices[0]
+        items: list[ModelResponsePart] = []
+        if choice.message.content is not None:
+            items.append(TextPart(choice.message.content))
         if choice.message.tool_calls is not None:
-            return ModelStructuredResponse(
-                [ToolCall.from_json(c.function.name, c.function.arguments, c.id) for c in choice.message.tool_calls],
-                timestamp=timestamp,
-            )
-        else:
-            assert choice.message.content is not None, choice
-            return ModelTextResponse(choice.message.content, timestamp=timestamp)
+            for c in choice.message.tool_calls:
+                items.append(ToolCallPart.from_json(c.function.name, c.function.arguments, c.id))
+        return ModelResponse(items, timestamp=timestamp)
 
     @staticmethod
     async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> EitherStreamedResponse:
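The four `model_settings.get(...)` lines above are the entire pass-through: any key the caller leaves unset stays `NOT_GIVEN` and is omitted from the API call. A hedged sketch of a caller-side settings dict; the keys shown are exactly the ones the diff reads, while how the dict reaches `request` depends on agent-layer plumbing not shown here:

from pydantic_ai.settings import ModelSettings

# Only the keys this diff reads: max_tokens, temperature, top_p, timeout.
settings: ModelSettings = {
    'max_tokens': 256,
    'temperature': 0.0,
    'top_p': 1.0,
    'timeout': 30.0,
}
# Keys left out of the dict fall back to NOT_GIVEN and are not sent to OpenAI.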
@@ -221,47 +243,57 @@ class OpenAIAgentModel(AgentModel):
                     )
                 # else continue until we get either delta.content or delta.tool_calls
 
-    @staticmethod
-    def _map_message(message: Message) -> chat.ChatCompletionMessageParam:
+    @classmethod
+    def _map_message(cls, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
         """Just maps a `pydantic_ai.Message` to a `openai.types.ChatCompletionMessageParam`."""
-        if message.role == 'system':
-            # SystemPrompt ->
-            return chat.ChatCompletionSystemMessageParam(role='system', content=message.content)
-        elif message.role == 'user':
-            # UserPrompt ->
-            return chat.ChatCompletionUserMessageParam(role='user', content=message.content)
-        elif message.role == 'tool-return':
-            # ToolReturn ->
-            return chat.ChatCompletionToolMessageParam(
-                role='tool',
-                tool_call_id=_guard_tool_id(message),
-                content=message.model_response_str(),
-            )
-        elif message.role == 'retry-prompt':
-            # RetryPrompt ->
-            if message.tool_name is None:
-                return chat.ChatCompletionUserMessageParam(role='user', content=message.model_response())
-            else:
-                return chat.ChatCompletionToolMessageParam(
-                    role='tool',
-                    tool_call_id=_guard_tool_id(message),
-                    content=message.model_response(),
-                )
-        elif message.role == 'model-text-response':
-            # ModelTextResponse ->
-            return chat.ChatCompletionAssistantMessageParam(role='assistant', content=message.content)
-        elif message.role == 'model-structured-response':
-            assert (
-                message.role == 'model-structured-response'
-            ), f'Expected role to be "llm-tool-calls", got {message.role}'
-            # ModelStructuredResponse ->
-            return chat.ChatCompletionAssistantMessageParam(
-                role='assistant',
-                tool_calls=[_map_tool_call(t) for t in message.calls],
-            )
+        if isinstance(message, ModelRequest):
+            yield from cls._map_user_message(message)
+        elif isinstance(message, ModelResponse):
+            texts: list[str] = []
+            tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
+            for item in message.parts:
+                if isinstance(item, TextPart):
+                    texts.append(item.content)
+                elif isinstance(item, ToolCallPart):
+                    tool_calls.append(_map_tool_call(item))
+                else:
+                    assert_never(item)
+            message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')
+            if texts:
+                # Note: model responses from this model should only have one text item, so the following
+                # shouldn't merge multiple texts into one unless you switch models between runs:
+                message_param['content'] = '\n\n'.join(texts)
+            if tool_calls:
+                message_param['tool_calls'] = tool_calls
+            yield message_param
         else:
             assert_never(message)
 
+    @classmethod
+    def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
+        for part in message.parts:
+            if isinstance(part, SystemPromptPart):
+                yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
+            elif isinstance(part, UserPromptPart):
+                yield chat.ChatCompletionUserMessageParam(role='user', content=part.content)
+            elif isinstance(part, ToolReturnPart):
+                yield chat.ChatCompletionToolMessageParam(
+                    role='tool',
+                    tool_call_id=_guard_tool_call_id(t=part, model_source='OpenAI'),
+                    content=part.model_response_str(),
+                )
+            elif isinstance(part, RetryPromptPart):
+                if part.tool_name is None:
+                    yield chat.ChatCompletionUserMessageParam(role='user', content=part.model_response())
+                else:
+                    yield chat.ChatCompletionToolMessageParam(
+                        role='tool',
+                        tool_call_id=_guard_tool_call_id(t=part, model_source='OpenAI'),
+                        content=part.model_response(),
+                    )
+            else:
+                assert_never(part)
+
 
 @dataclass
 class OpenAIStreamTextResponse(StreamTextResponse):
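To make the new mapping concrete: each `ModelRequest` fans out into one OpenAI message per part, while a `ModelResponse` collapses into a single assistant message. A rough sketch of a history under the part-based types; constructor shapes are assumed from the imports above, and the values are illustrative:

from pydantic_ai.messages import (
    ModelRequest,
    ModelResponse,
    SystemPromptPart,
    TextPart,
    UserPromptPart,
)

# One request with two parts -> two OpenAI messages (system + user);
# one response with one text part -> a single assistant message.
history = [
    ModelRequest(parts=[
        SystemPromptPart(content='Be concise.'),
        UserPromptPart(content='Hello!'),
    ]),
    ModelResponse(parts=[TextPart(content='Hi there!')]),
]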
@@ -335,14 +367,14 @@ class OpenAIStreamStructuredResponse(StreamStructuredResponse):
             else:
                 self._delta_tool_calls[new.index] = new
 
-    def get(self, *, final: bool = False) -> ModelStructuredResponse:
-        calls: list[ToolCall] = []
+    def get(self, *, final: bool = False) -> ModelResponse:
+        items: list[ModelResponsePart] = []
         for c in self._delta_tool_calls.values():
             if f := c.function:
                 if f.name is not None and f.arguments is not None:
-                    calls.append(ToolCall.from_json(f.name, f.arguments, c.id))
+                    items.append(ToolCallPart.from_json(f.name, f.arguments, c.id))
 
-        return ModelStructuredResponse(calls, timestamp=self._timestamp)
+        return ModelResponse(items, timestamp=self._timestamp)
 
     def cost(self) -> Cost:
         return self._cost
@@ -351,16 +383,10 @@ class OpenAIStreamStructuredResponse(StreamStructuredResponse):
         return self._timestamp
 
 
-def _guard_tool_id(t: ToolCall | ToolReturn | RetryPrompt) -> str:
-    """Type guard that checks a `tool_id` is not None both for static typing and runtime."""
-    assert t.tool_id is not None, f'OpenAI requires `tool_id` to be set: {t}'
-    return t.tool_id
-
-
-def _map_tool_call(t: ToolCall) -> chat.ChatCompletionMessageToolCallParam:
+def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
     assert isinstance(t.args, ArgsJson), f'Expected ArgsJson, got {t.args}'
     return chat.ChatCompletionMessageToolCallParam(
-        id=_guard_tool_id(t),
+        id=_guard_tool_call_id(t=t, model_source='OpenAI'),
         type='function',
         function={'name': t.tool_name, 'arguments': t.args.args_json},
     )
pydantic_ai/models/test.py:

@@ -9,18 +9,20 @@ from datetime import date, datetime, timedelta
 from typing import Any, Literal
 
 import pydantic_core
+from typing_extensions import assert_never
 
 from .. import _utils
 from ..messages import (
-    Message,
-    ModelAnyResponse,
-    ModelStructuredResponse,
-    ModelTextResponse,
-    RetryPrompt,
-    ToolCall,
-    ToolReturn,
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    RetryPromptPart,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
 )
 from ..result import Cost
+from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
     AgentModel,
@@ -127,74 +129,83 @@ class TestAgentModel(AgentModel):
     result_tools: list[ToolDefinition]
     seed: int
 
-    async def request(self, messages: list[Message]) -> tuple[ModelAnyResponse, Cost]:
-        return self._request(messages), Cost()
+    async def request(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> tuple[ModelResponse, Cost]:
+        return self._request(messages, model_settings), Cost()
 
     @asynccontextmanager
-    async def request_stream(self, messages: list[Message]) -> AsyncIterator[EitherStreamedResponse]:
-        msg = self._request(messages)
+    async def request_stream(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> AsyncIterator[EitherStreamedResponse]:
+        msg = self._request(messages, model_settings)
         cost = Cost()
-        if isinstance(msg, ModelTextResponse):
-            yield TestStreamTextResponse(msg.content, cost)
+
+        # TODO: Rework this once we make StreamTextResponse more general
+        texts: list[str] = []
+        tool_calls: list[ToolCallPart] = []
+        for item in msg.parts:
+            if isinstance(item, TextPart):
+                texts.append(item.content)
+            elif isinstance(item, ToolCallPart):
+                tool_calls.append(item)
+            else:
+                assert_never(item)
+
+        if texts:
+            yield TestStreamTextResponse('\n\n'.join(texts), cost)
         else:
             yield TestStreamStructuredResponse(msg, cost)
 
     def gen_tool_args(self, tool_def: ToolDefinition) -> Any:
         return _JsonSchemaTestData(tool_def.parameters_json_schema, self.seed).generate()
 
-    def _request(self, messages: list[Message]) -> ModelAnyResponse:
+    def _request(self, messages: list[ModelMessage], model_settings: ModelSettings | None) -> ModelResponse:
         # if there are tools, the first thing we want to do is call all of them
-        if self.tool_calls and not any(m.role == 'model-structured-response' for m in messages):
-            calls = [ToolCall.from_dict(name, self.gen_tool_args(args)) for name, args in self.tool_calls]
-            return ModelStructuredResponse(calls=calls)
-
-        # get messages since the last model response
-        new_messages = _get_new_messages(messages)
-
-        # check if there are any retry prompts, if so retry them
-        new_retry_names = {m.tool_name for m in new_messages if isinstance(m, RetryPrompt)}
-        if new_retry_names:
-            calls = [
-                ToolCall.from_dict(name, self.gen_tool_args(args))
-                for name, args in self.tool_calls
-                if name in new_retry_names
-            ]
-            return ModelStructuredResponse(calls=calls)
+        if self.tool_calls and not any(isinstance(m, ModelResponse) for m in messages):
+            return ModelResponse(
+                parts=[ToolCallPart.from_dict(name, self.gen_tool_args(args)) for name, args in self.tool_calls]
+            )
+
+        if messages:
+            last_message = messages[-1]
+            assert isinstance(last_message, ModelRequest), 'Expected last message to be a `ModelRequest`.'
+
+            # check if there are any retry prompts, if so retry them
+            new_retry_names = {p.tool_name for p in last_message.parts if isinstance(p, RetryPromptPart)}
+            if new_retry_names:
+                return ModelResponse(
+                    parts=[
+                        ToolCallPart.from_dict(name, self.gen_tool_args(args))
+                        for name, args in self.tool_calls
+                        if name in new_retry_names
+                    ]
+                )
 
         if response_text := self.result.left:
             if response_text.value is None:
                 # build up details of tool responses
                 output: dict[str, Any] = {}
                 for message in messages:
-                    if isinstance(message, ToolReturn):
-                        output[message.tool_name] = message.content
+                    if isinstance(message, ModelRequest):
+                        for part in message.parts:
+                            if isinstance(part, ToolReturnPart):
+                                output[part.tool_name] = part.content
                 if output:
-                    return ModelTextResponse(content=pydantic_core.to_json(output).decode())
+                    return ModelResponse.from_text(pydantic_core.to_json(output).decode())
                 else:
-                    return ModelTextResponse(content='success (no tool calls)')
+                    return ModelResponse.from_text('success (no tool calls)')
             else:
-                return ModelTextResponse(content=response_text.value)
+                return ModelResponse.from_text(response_text.value)
         else:
             assert self.result_tools, 'No result tools provided'
             custom_result_args = self.result.right
             result_tool = self.result_tools[self.seed % len(self.result_tools)]
             if custom_result_args is not None:
-                return ModelStructuredResponse(calls=[ToolCall.from_dict(result_tool.name, custom_result_args)])
+                return ModelResponse(parts=[ToolCallPart.from_dict(result_tool.name, custom_result_args)])
             else:
                 response_args = self.gen_tool_args(result_tool)
-                return ModelStructuredResponse(calls=[ToolCall.from_dict(result_tool.name, response_args)])
-
-
-def _get_new_messages(messages: list[Message]) -> list[Message]:
-    last_model_index = None
-    for i, m in enumerate(messages):
-        if m.role in ('model-structured-response', 'model-text-response'):
-            last_model_index = i
-
-    if last_model_index is not None:
-        return messages[last_model_index + 1 :]
-    else:
-        return []
+                return ModelResponse(parts=[ToolCallPart.from_dict(result_tool.name, response_args)])
 
 
 @dataclass
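The `_request` rewrite preserves `TestModel`'s observable behaviour: call every registered tool first, then answer with either the configured text or a JSON digest of the tool returns. A quick hedged sketch, assuming `TestModel` remains the public wrapper around `TestAgentModel`; the `'success (no tool calls)'` string comes straight from the diff:

from pydantic_ai import Agent
from pydantic_ai.models.test import TestModel

# Sketch: TestModel produces canned responses, with no network calls.
agent = Agent(TestModel())
result = agent.run_sync('anything')
print(result.data)  # 'success (no tool calls)' when no tools are registered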
@@ -234,7 +245,7 @@ class TestStreamTextResponse(StreamTextResponse):
 class TestStreamStructuredResponse(StreamStructuredResponse):
     """A structured response that streams test data."""
 
-    _structured_response: ModelStructuredResponse
+    _structured_response: ModelResponse
     _cost: Cost
     _iter: Iterator[None] = field(default_factory=lambda: iter([None]))
     _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
@@ -242,7 +253,7 @@ class TestStreamStructuredResponse(StreamStructuredResponse):
     async def __anext__(self) -> None:
         return _utils.sync_anext(self._iter)
 
-    def get(self, *, final: bool = False) -> ModelStructuredResponse:
+    def get(self, *, final: bool = False) -> ModelResponse:
         return self._structured_response
 
     def cost(self) -> Cost:
pydantic_ai/models/vertexai.py:

@@ -21,7 +21,7 @@ try:
 except ImportError as _import_error:
     raise ImportError(
         'Please install `google-auth` to use the VertexAI model, '
-        "you can use the `vertexai` optional group — `pip install 'pydantic-ai[vertexai]'`"
+        "you can use the `vertexai` optional group — `pip install 'pydantic-ai-slim[vertexai]'`"
     ) from _import_error
 
 VERTEX_AI_URL_TEMPLATE = (
@@ -114,7 +114,7 @@ class VertexAIModel(Model):
         allow_text_result: bool,
         result_tools: list[ToolDefinition],
     ) -> GeminiAgentModel:
-        url, auth = await self._ainit()
+        url, auth = await self.ainit()
         return GeminiAgentModel(
             http_client=self.http_client,
             model_name=self.model_name,
@@ -125,7 +125,11 @@ class VertexAIModel(Model):
             result_tools=result_tools,
         )
 
-    async def _ainit(self) -> tuple[str, BearerTokenAuth]:
+    async def ainit(self) -> tuple[str, BearerTokenAuth]:
+        """Initialize the model, setting the URL and auth.
+
+        This will raise an error if authentication fails.
+        """
         if self.url is not None and self.auth is not None:
             return self.url, self.auth
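With `_ainit` now public as `ainit`, callers can force URL and auth resolution ahead of the first request. A minimal sketch; only the `ainit` signature comes from the diff, while the model name and the eager-initialization pattern are illustrative:

import asyncio

from pydantic_ai.models.vertexai import VertexAIModel

async def main() -> None:
    model = VertexAIModel('gemini-1.5-flash')  # placeholder model name
    # Resolve the request URL and bearer token now; raises if auth fails.
    url, auth = await model.ainit()
    print(url)

asyncio.run(main())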