pydantic-ai-slim 0.0.20__py3-none-any.whl → 0.0.22__py3-none-any.whl

This diff shows the content of two publicly released versions of the package as they appear in their public registry; it is provided for informational purposes only.

Note: this release of pydantic-ai-slim was flagged as potentially problematic; see the package's registry page for details.

pydantic_ai/models/groq.py CHANGED
@@ -5,7 +5,7 @@ from contextlib import asynccontextmanager
  from dataclasses import dataclass, field
  from datetime import datetime, timezone
  from itertools import chain
- from typing import Literal, overload
+ from typing import Literal, cast, overload

  from httpx import AsyncClient as AsyncHTTPClient
  from typing_extensions import assert_never
@@ -47,10 +47,7 @@ except ImportError as _import_error:

  GroqModelName = Literal[
  'llama-3.3-70b-versatile',
- 'llama-3.1-70b-versatile',
- 'llama3-groq-70b-8192-tool-use-preview',
- 'llama3-groq-8b-8192-tool-use-preview',
- 'llama-3.1-70b-specdec',
+ 'llama-3.3-70b-specdec',
  'llama-3.1-8b-instant',
  'llama-3.2-1b-preview',
  'llama-3.2-3b-preview',
@@ -60,7 +57,6 @@ GroqModelName = Literal[
  'llama3-8b-8192',
  'mixtral-8x7b-32768',
  'gemma2-9b-it',
- 'gemma-7b-it',
  ]
  """Named Groq models.

@@ -68,6 +64,12 @@ See [the Groq docs](https://console.groq.com/docs/models) for a full list.
  """


+ class GroqModelSettings(ModelSettings):
+ """Settings used for a Groq model request."""
+
+ # This class is a placeholder for any future groq-specific settings
+
+
  @dataclass(init=False)
  class GroqModel(Model):
  """A model that uses the Groq API.
@@ -155,31 +157,31 @@ class GroqAgentModel(AgentModel):
  async def request(
  self, messages: list[ModelMessage], model_settings: ModelSettings | None
  ) -> tuple[ModelResponse, usage.Usage]:
- response = await self._completions_create(messages, False, model_settings)
+ response = await self._completions_create(messages, False, cast(GroqModelSettings, model_settings or {}))
  return self._process_response(response), _map_usage(response)

  @asynccontextmanager
  async def request_stream(
  self, messages: list[ModelMessage], model_settings: ModelSettings | None
  ) -> AsyncIterator[StreamedResponse]:
- response = await self._completions_create(messages, True, model_settings)
+ response = await self._completions_create(messages, True, cast(GroqModelSettings, model_settings or {}))
  async with response:
  yield await self._process_streamed_response(response)

  @overload
  async def _completions_create(
- self, messages: list[ModelMessage], stream: Literal[True], model_settings: ModelSettings | None
+ self, messages: list[ModelMessage], stream: Literal[True], model_settings: GroqModelSettings
  ) -> AsyncStream[ChatCompletionChunk]:
  pass

  @overload
  async def _completions_create(
- self, messages: list[ModelMessage], stream: Literal[False], model_settings: ModelSettings | None
+ self, messages: list[ModelMessage], stream: Literal[False], model_settings: GroqModelSettings
  ) -> chat.ChatCompletion:
  pass

  async def _completions_create(
- self, messages: list[ModelMessage], stream: bool, model_settings: ModelSettings | None
+ self, messages: list[ModelMessage], stream: bool, model_settings: GroqModelSettings
  ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
  # standalone function to make it easier to override
  if not self.tools:
@@ -191,13 +193,11 @@ class GroqAgentModel(AgentModel):

  groq_messages = list(chain(*(self._map_message(m) for m in messages)))

- model_settings = model_settings or {}
-
  return await self.client.chat.completions.create(
  model=str(self.model_name),
  messages=groq_messages,
  n=1,
- parallel_tool_calls=model_settings.get('parallel_tool_calls', True if self.tools else NOT_GIVEN),
+ parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
  tools=self.tools or NOT_GIVEN,
  tool_choice=tool_choice or NOT_GIVEN,
  stream=stream,
@@ -205,6 +205,10 @@ class GroqAgentModel(AgentModel):
  temperature=model_settings.get('temperature', NOT_GIVEN),
  top_p=model_settings.get('top_p', NOT_GIVEN),
  timeout=model_settings.get('timeout', NOT_GIVEN),
+ seed=model_settings.get('seed', NOT_GIVEN),
+ presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
+ frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
+ logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
  )

  def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
@@ -216,9 +220,7 @@ class GroqAgentModel(AgentModel):
  items.append(TextPart(content=choice.message.content))
  if choice.message.tool_calls is not None:
  for c in choice.message.tool_calls:
- items.append(
- ToolCallPart.from_raw_args(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id)
- )
+ items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
  return ModelResponse(items, model_name=self.model_name, timestamp=timestamp)

  async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> GroqStreamedResponse:
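
The `GroqModelSettings` placeholder and the repeated `cast(GroqModelSettings, model_settings or {})` calls above rely on `ModelSettings` being a `total=False` `TypedDict`: an empty dict is already a valid instance, so the cast only guides the type checker. A minimal self-contained sketch of the pattern, with simplified names rather than the package's real module layout:

```python
from __future__ import annotations

from typing import TypedDict, cast


class ModelSettings(TypedDict, total=False):
    """Common request settings; every key is optional."""

    temperature: float
    seed: int


class GroqModelSettings(ModelSettings):
    """Placeholder for future Groq-specific settings."""


def build_request_kwargs(model_settings: ModelSettings | None) -> dict[str, object]:
    # `or {}` normalises None away; the cast is safe at runtime because a
    # total=False TypedDict places no requirements on which keys are present.
    settings = cast(GroqModelSettings, model_settings or {})
    return {'temperature': settings.get('temperature', 0.7), 'seed': settings.get('seed')}


print(build_request_kwargs({'temperature': 0.2}))  # {'temperature': 0.2, 'seed': None}
```
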
pydantic_ai/models/mistral.py CHANGED
@@ -6,7 +6,7 @@ from contextlib import asynccontextmanager
  from dataclasses import dataclass, field
  from datetime import datetime, timezone
  from itertools import chain
- from typing import Any, Callable, Literal, Union
+ from typing import Any, Callable, Literal, Union, cast

  import pydantic_core
  from httpx import AsyncClient as AsyncHTTPClient, Timeout
@@ -15,7 +15,6 @@ from typing_extensions import assert_never
  from .. import UnexpectedModelBehavior, _utils
  from .._utils import now_utc as _now_utc
  from ..messages import (
- ArgsJson,
  ModelMessage,
  ModelRequest,
  ModelResponse,
@@ -85,6 +84,12 @@ Since [the Mistral docs](https://docs.mistral.ai/getting-started/models/models_o
  """


+ class MistralModelSettings(ModelSettings):
+ """Settings used for a Mistral model request."""
+
+ # This class is a placeholder for any future mistral-specific settings
+
+
  @dataclass(init=False)
  class MistralModel(Model):
  """A model that uses Mistral.
@@ -159,7 +164,7 @@ class MistralAgentModel(AgentModel):
  self, messages: list[ModelMessage], model_settings: ModelSettings | None
  ) -> tuple[ModelResponse, Usage]:
  """Make a non-streaming request to the model from Pydantic AI call."""
- response = await self._completions_create(messages, model_settings)
+ response = await self._completions_create(messages, cast(MistralModelSettings, model_settings or {}))
  return self._process_response(response), _map_usage(response)

  @asynccontextmanager
@@ -167,15 +172,14 @@ class MistralAgentModel(AgentModel):
  self, messages: list[ModelMessage], model_settings: ModelSettings | None
  ) -> AsyncIterator[StreamedResponse]:
  """Make a streaming request to the model from Pydantic AI call."""
- response = await self._stream_completions_create(messages, model_settings)
+ response = await self._stream_completions_create(messages, cast(MistralModelSettings, model_settings or {}))
  async with response:
  yield await self._process_streamed_response(self.result_tools, response)

  async def _completions_create(
- self, messages: list[ModelMessage], model_settings: ModelSettings | None
+ self, messages: list[ModelMessage], model_settings: MistralModelSettings
  ) -> MistralChatCompletionResponse:
  """Make a non-streaming request to the model."""
- model_settings = model_settings or {}
  response = await self.client.chat.complete_async(
  model=str(self.model_name),
  messages=list(chain(*(self._map_message(m) for m in messages))),
@@ -187,6 +191,7 @@ class MistralAgentModel(AgentModel):
  temperature=model_settings.get('temperature', UNSET),
  top_p=model_settings.get('top_p', 1),
  timeout_ms=self._get_timeout_ms(model_settings.get('timeout')),
+ random_seed=model_settings.get('seed', UNSET),
  )
  assert response, 'A unexpected empty response from Mistral.'
  return response
@@ -194,12 +199,11 @@ class MistralAgentModel(AgentModel):
  async def _stream_completions_create(
  self,
  messages: list[ModelMessage],
- model_settings: ModelSettings | None,
+ model_settings: MistralModelSettings,
  ) -> MistralEventStreamAsync[MistralCompletionEvent]:
  """Create a streaming completion request to the Mistral model."""
  response: MistralEventStreamAsync[MistralCompletionEvent] | None
  mistral_messages = list(chain(*(self._map_message(m) for m in messages)))
- model_settings = model_settings or {}

  if self.result_tools and self.function_tools or self.function_tools:
  # Function Calling
@@ -213,6 +217,8 @@ class MistralAgentModel(AgentModel):
  top_p=model_settings.get('top_p', 1),
  max_tokens=model_settings.get('max_tokens', UNSET),
  timeout_ms=self._get_timeout_ms(model_settings.get('timeout')),
+ presence_penalty=model_settings.get('presence_penalty'),
+ frequency_penalty=model_settings.get('frequency_penalty'),
  )

  elif self.result_tools:
@@ -317,18 +323,11 @@ class MistralAgentModel(AgentModel):
  @staticmethod
  def _map_to_mistral_tool_call(t: ToolCallPart) -> MistralToolCall:
  """Maps a pydantic-ai ToolCall to a MistralToolCall."""
- if isinstance(t.args, ArgsJson):
- return MistralToolCall(
- id=t.tool_call_id,
- type='function',
- function=MistralFunctionCall(name=t.tool_name, arguments=t.args.args_json),
- )
- else:
- return MistralToolCall(
- id=t.tool_call_id,
- type='function',
- function=MistralFunctionCall(name=t.tool_name, arguments=t.args.args_dict),
- )
+ return MistralToolCall(
+ id=t.tool_call_id,
+ type='function',
+ function=MistralFunctionCall(name=t.tool_name, arguments=t.args),
+ )

  def _generate_user_output_format(self, schemas: list[dict[str, Any]]) -> MistralUserMessage:
  """Get a message with an example of the expected output format."""
@@ -511,7 +510,7 @@ class MistralStreamedResponse(StreamedResponse):
  continue

  # The following part_id will be thrown away
- return ToolCallPart.from_raw_args(tool_name=result_tool.name, args=output_json)
+ return ToolCallPart(tool_name=result_tool.name, args=output_json)

  @staticmethod
  def _validate_required_json_schema(json_dict: dict[str, Any], json_schema: dict[str, Any]) -> bool:
@@ -569,7 +568,7 @@ def _map_mistral_to_pydantic_tool_call(tool_call: MistralToolCall) -> ToolCallPa
  tool_call_id = tool_call.id or None
  func_call = tool_call.function

- return ToolCallPart.from_raw_args(func_call.name, func_call.arguments, tool_call_id)
+ return ToolCallPart(func_call.name, func_call.arguments, tool_call_id)


  def _map_usage(response: MistralChatCompletionResponse | MistralCompletionChunk) -> Usage:
@@ -600,7 +599,7 @@ def _map_content(content: MistralOptionalNullable[MistralContent]) -> str | None
  elif isinstance(content, str):
  result = content

- # Note: Check len to handle potential mismatch between function calls and responses from the API. (`msg: not the same number of function class and reponses`)
+ # Note: Check len to handle potential mismatch between function calls and responses from the API. (`msg: not the same number of function class and responses`)
  if result and len(result) == 0:
  result = None

pydantic_ai/models/openai.py CHANGED
@@ -1,11 +1,12 @@
  from __future__ import annotations as _annotations

+ import os
  from collections.abc import AsyncIterable, AsyncIterator, Iterable
  from contextlib import asynccontextmanager
  from dataclasses import dataclass, field
  from datetime import datetime, timezone
  from itertools import chain
- from typing import Literal, Union, overload
+ from typing import Literal, Union, cast, overload

  from httpx import AsyncClient as AsyncHTTPClient
  from typing_extensions import assert_never
@@ -48,12 +49,18 @@ except ImportError as _import_error:
  OpenAIModelName = Union[ChatModel, str]
  """
  Using this more broad type for the model name instead of the ChatModel definition
- allows this model to be used more easily with other model types (ie, Ollama)
+ allows this model to be used more easily with other model types (ie, Ollama, Deepseek)
  """

  OpenAISystemPromptRole = Literal['system', 'developer', 'user']


+ class OpenAIModelSettings(ModelSettings):
+ """Settings used for an OpenAI model request."""
+
+ # This class is a placeholder for any future openai-specific settings
+
+
  @dataclass(init=False)
  class OpenAIModel(Model):
  """A model that uses the OpenAI API.
@@ -95,7 +102,11 @@ class OpenAIModel(Model):
  In the future, this may be inferred from the model name.
  """
  self.model_name: OpenAIModelName = model_name
- if openai_client is not None:
+ # This is a workaround for the OpenAI client requiring an API key, whilst locally served,
+ # openai compatible models do not always need an API key.
+ if api_key is None and 'OPENAI_API_KEY' not in os.environ and base_url is not None and openai_client is None:
+ api_key = ''
+ elif openai_client is not None:
  assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
  assert base_url is None, 'Cannot provide both `openai_client` and `base_url`'
  assert api_key is None, 'Cannot provide both `openai_client` and `api_key`'
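
The API-key workaround above means a locally served, OpenAI-compatible endpoint can now be used without exporting `OPENAI_API_KEY`: when a `base_url` is given and no key is available, the key falls back to an empty string. A hedged sketch, where the base URL and model name are placeholders for whatever local server you actually run:

```python
from pydantic_ai.models.openai import OpenAIModel

# With base_url set and neither api_key nor OPENAI_API_KEY present,
# the constructor now substitutes an empty API key instead of failing.
model = OpenAIModel(
    'llama3.2',  # hypothetical model name served locally
    base_url='http://localhost:11434/v1',  # hypothetical local endpoint
)
```
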
@@ -153,31 +164,31 @@ class OpenAIAgentModel(AgentModel):
  async def request(
  self, messages: list[ModelMessage], model_settings: ModelSettings | None
  ) -> tuple[ModelResponse, usage.Usage]:
- response = await self._completions_create(messages, False, model_settings)
+ response = await self._completions_create(messages, False, cast(OpenAIModelSettings, model_settings or {}))
  return self._process_response(response), _map_usage(response)

  @asynccontextmanager
  async def request_stream(
  self, messages: list[ModelMessage], model_settings: ModelSettings | None
  ) -> AsyncIterator[StreamedResponse]:
- response = await self._completions_create(messages, True, model_settings)
+ response = await self._completions_create(messages, True, cast(OpenAIModelSettings, model_settings or {}))
  async with response:
  yield await self._process_streamed_response(response)

  @overload
  async def _completions_create(
- self, messages: list[ModelMessage], stream: Literal[True], model_settings: ModelSettings | None
+ self, messages: list[ModelMessage], stream: Literal[True], model_settings: OpenAIModelSettings
  ) -> AsyncStream[ChatCompletionChunk]:
  pass

  @overload
  async def _completions_create(
- self, messages: list[ModelMessage], stream: Literal[False], model_settings: ModelSettings | None
+ self, messages: list[ModelMessage], stream: Literal[False], model_settings: OpenAIModelSettings
  ) -> chat.ChatCompletion:
  pass

  async def _completions_create(
- self, messages: list[ModelMessage], stream: bool, model_settings: ModelSettings | None
+ self, messages: list[ModelMessage], stream: bool, model_settings: OpenAIModelSettings
  ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
  # standalone function to make it easier to override
  if not self.tools:
@@ -189,13 +200,11 @@ class OpenAIAgentModel(AgentModel):

  openai_messages = list(chain(*(self._map_message(m) for m in messages)))

- model_settings = model_settings or {}
-
  return await self.client.chat.completions.create(
  model=self.model_name,
  messages=openai_messages,
  n=1,
- parallel_tool_calls=model_settings.get('parallel_tool_calls', True if self.tools else NOT_GIVEN),
+ parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
  tools=self.tools or NOT_GIVEN,
  tool_choice=tool_choice or NOT_GIVEN,
  stream=stream,
@@ -204,6 +213,10 @@ class OpenAIAgentModel(AgentModel):
  temperature=model_settings.get('temperature', NOT_GIVEN),
  top_p=model_settings.get('top_p', NOT_GIVEN),
  timeout=model_settings.get('timeout', NOT_GIVEN),
+ seed=model_settings.get('seed', NOT_GIVEN),
+ presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
+ frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
+ logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
  )

  def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
@@ -215,7 +228,7 @@ class OpenAIAgentModel(AgentModel):
  items.append(TextPart(choice.message.content))
  if choice.message.tool_calls is not None:
  for c in choice.message.tool_calls:
- items.append(ToolCallPart.from_raw_args(c.function.name, c.function.arguments, c.id))
+ items.append(ToolCallPart(c.function.name, c.function.arguments, c.id))
  return ModelResponse(items, model_name=self.model_name, timestamp=timestamp)

  async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> OpenAIStreamedResponse:
pydantic_ai/models/test.py CHANGED
@@ -12,7 +12,6 @@ import pydantic_core

  from .. import _utils
  from ..messages import (
- ArgsJson,
  ModelMessage,
  ModelRequest,
  ModelResponse,
@@ -34,6 +33,20 @@ from . import (
  from .function import _estimate_string_tokens, _estimate_usage # pyright: ignore[reportPrivateUsage]


+ @dataclass
+ class _TextResult:
+ """A private wrapper class to tag a result that came from the custom_result_text field."""
+
+ value: str | None
+
+
+ @dataclass
+ class _FunctionToolResult:
+ """A wrapper class to tag a result that came from the custom_result_args field."""
+
+ value: Any | None
+
+
  @dataclass
  class TestModel(Model):
  """A model specifically for testing purposes.
@@ -53,7 +66,7 @@ class TestModel(Model):
  call_tools: list[str] | Literal['all'] = 'all'
  """List of tools to call. If `'all'`, all tools will be called."""
  custom_result_text: str | None = None
- """If set, this text is return as the final result."""
+ """If set, this text is returned as the final result."""
  custom_result_args: Any | None = None
  """If set, these args will be passed to the result tool."""
  seed: int = 0
@@ -95,21 +108,21 @@ class TestModel(Model):
  if self.custom_result_text is not None:
  assert allow_text_result, 'Plain response not allowed, but `custom_result_text` is set.'
  assert self.custom_result_args is None, 'Cannot set both `custom_result_text` and `custom_result_args`.'
- result: _utils.Either[str | None, Any | None] = _utils.Either(left=self.custom_result_text)
+ result: _TextResult | _FunctionToolResult = _TextResult(self.custom_result_text)
  elif self.custom_result_args is not None:
  assert result_tools is not None, 'No result tools provided, but `custom_result_args` is set.'
  result_tool = result_tools[0]

  if k := result_tool.outer_typed_dict_key:
- result = _utils.Either(right={k: self.custom_result_args})
+ result = _FunctionToolResult({k: self.custom_result_args})
  else:
- result = _utils.Either(right=self.custom_result_args)
+ result = _FunctionToolResult(self.custom_result_args)
  elif allow_text_result:
- result = _utils.Either(left=None)
+ result = _TextResult(None)
  elif result_tools:
- result = _utils.Either(right=None)
+ result = _FunctionToolResult(None)
  else:
- result = _utils.Either(left=None)
+ result = _TextResult(None)

  return TestAgentModel(tool_calls, result, result_tools, self.seed)

@@ -126,7 +139,7 @@ class TestAgentModel(AgentModel):

  tool_calls: list[tuple[str, ToolDefinition]]
  # left means the text is plain text; right means it's a function call
- result: _utils.Either[str | None, Any | None]
+ result: _TextResult | _FunctionToolResult
  result_tools: list[ToolDefinition]
  seed: int
  model_name: str = 'test'
@@ -152,7 +165,7 @@ class TestAgentModel(AgentModel):
  # if there are tools, the first thing we want to do is call all of them
  if self.tool_calls and not any(isinstance(m, ModelResponse) for m in messages):
  return ModelResponse(
- parts=[ToolCallPart.from_raw_args(name, self.gen_tool_args(args)) for name, args in self.tool_calls],
+ parts=[ToolCallPart(name, self.gen_tool_args(args)) for name, args in self.tool_calls],
  model_name=self.model_name,
  )

@@ -166,7 +179,7 @@ class TestAgentModel(AgentModel):
  # Handle retries for both function tools and result tools
  # Check function tools first
  retry_parts: list[ModelResponsePart] = [
- ToolCallPart.from_raw_args(name, self.gen_tool_args(args))
+ ToolCallPart(name, self.gen_tool_args(args))
  for name, args in self.tool_calls
  if name in new_retry_names
  ]
@@ -174,15 +187,20 @@ class TestAgentModel(AgentModel):
  if self.result_tools:
  retry_parts.extend(
  [
- ToolCallPart.from_raw_args(tool.name, self.gen_tool_args(tool))
+ ToolCallPart(
+ tool.name,
+ self.result.value
+ if isinstance(self.result, _FunctionToolResult) and self.result.value is not None
+ else self.gen_tool_args(tool),
+ )
  for tool in self.result_tools
  if tool.name in new_retry_names
  ]
  )
  return ModelResponse(parts=retry_parts, model_name=self.model_name)

- if response_text := self.result.left:
- if response_text.value is None:
+ if isinstance(self.result, _TextResult):
+ if (response_text := self.result.value) is None:
  # build up details of tool responses
  output: dict[str, Any] = {}
  for message in messages:
@@ -197,20 +215,18 @@ class TestAgentModel(AgentModel):
  else:
  return ModelResponse(parts=[TextPart('success (no tool calls)')], model_name=self.model_name)
  else:
- return ModelResponse(parts=[TextPart(response_text.value)], model_name=self.model_name)
+ return ModelResponse(parts=[TextPart(response_text)], model_name=self.model_name)
  else:
  assert self.result_tools, 'No result tools provided'
- custom_result_args = self.result.right
+ custom_result_args = self.result.value
  result_tool = self.result_tools[self.seed % len(self.result_tools)]
  if custom_result_args is not None:
  return ModelResponse(
- parts=[ToolCallPart.from_raw_args(result_tool.name, custom_result_args)], model_name=self.model_name
+ parts=[ToolCallPart(result_tool.name, custom_result_args)], model_name=self.model_name
  )
  else:
  response_args = self.gen_tool_args(result_tool)
- return ModelResponse(
- parts=[ToolCallPart.from_raw_args(result_tool.name, response_args)], model_name=self.model_name
- )
+ return ModelResponse(parts=[ToolCallPart(result_tool.name, response_args)], model_name=self.model_name)


  @dataclass
@@ -241,9 +257,8 @@ class TestStreamedResponse(StreamedResponse):
  self._usage += _get_string_usage(word)
  yield self._parts_manager.handle_text_delta(vendor_part_id=i, content=word)
  else:
- args = part.args.args_json if isinstance(part.args, ArgsJson) else part.args.args_dict
  yield self._parts_manager.handle_tool_call_part(
- vendor_part_id=i, tool_name=part.tool_name, args=args, tool_call_id=part.tool_call_id
+ vendor_part_id=i, tool_name=part.tool_name, args=part.args, tool_call_id=part.tool_call_id
  )

  def timestamp(self) -> datetime:
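
For context on how these wrapper classes surface to users, a hedged usage sketch of `TestModel`; the agent and prompt are illustrative and not taken from the diff:

```python
from pydantic_ai import Agent
from pydantic_ai.models.test import TestModel

# custom_result_text is wrapped in the new _TextResult internally and
# returned verbatim as the plain-text final result.
agent = Agent(TestModel(custom_result_text='fixed answer'))

result = agent.run_sync('anything')
print(result.data)  # fixed answer
```
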
pydantic_ai/result.py CHANGED
@@ -46,7 +46,7 @@ A function that always takes and returns the same type of data (which is the res
  * may or may not take [`RunContext`][pydantic_ai.tools.RunContext] as a first argument
  * may or may not be async

- Usage `ResultValidatorFunc[AgentDeps, T]`.
+ Usage `ResultValidatorFunc[AgentDepsT, T]`.
  """

  _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')
pydantic_ai/settings.py CHANGED
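
The hunk below adds `seed`, `presence_penalty`, `frequency_penalty` and `logit_bias` to `ModelSettings`. Since `ModelSettings` is a `total=False` `TypedDict`, any subset of these keys can be passed as a plain dict wherever `model_settings` is accepted, and keys a given provider does not map are simply not forwarded. A hedged sketch; the model name and values are illustrative:

```python
from pydantic_ai import Agent

agent = Agent('openai:gpt-4o')

result = agent.run_sync(
    'Say hello',
    model_settings={
        'seed': 42,
        'presence_penalty': 0.5,
        'frequency_penalty': 0.3,
        'logit_bias': {'50256': -100},
    },
)
print(result.data)
```
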
@@ -80,11 +80,56 @@ class ModelSettings(TypedDict, total=False):
  """Whether to allow parallel tool calls.

  Supported by:
- * OpenAI
+
+ * OpenAI (some models, not o1)
  * Groq
  * Anthropic
  """

+ seed: int
+ """The random seed to use for the model, theoretically allowing for deterministic results.
+
+ Supported by:
+
+ * OpenAI
+ * Groq
+ * Cohere
+ * Mistral
+ """
+
+ presence_penalty: float
+ """Penalize new tokens based on whether they have appeared in the text so far.
+
+ Supported by:
+
+ * OpenAI
+ * Groq
+ * Cohere
+ * Gemini
+ * Mistral
+ """
+
+ frequency_penalty: float
+ """Penalize new tokens based on their existing frequency in the text so far.
+
+ Supported by:
+
+ * OpenAI
+ * Groq
+ * Cohere
+ * Gemini
+ * Mistral
+ """
+
+ logit_bias: dict[str, int]
+ """Modify the likelihood of specified tokens appearing in the completion.
+
+ Supported by:
+
+ * OpenAI
+ * Groq
+ """
+

  def merge_model_settings(base: ModelSettings | None, overrides: ModelSettings | None) -> ModelSettings | None:
  """Merge two sets of model settings, preferring the overrides.
pydantic_ai/tools.py CHANGED
@@ -79,13 +79,13 @@ SystemPromptFunc = Union[
  ]
  """A function that may or maybe not take `RunContext` as an argument, and may or may not be async.

- Usage `SystemPromptFunc[AgentDeps]`.
+ Usage `SystemPromptFunc[AgentDepsT]`.
  """

  ToolFuncContext = Callable[Concatenate[RunContext[AgentDepsT], ToolParams], Any]
  """A tool function that takes `RunContext` as the first argument.

- Usage `ToolContextFunc[AgentDeps, ToolParams]`.
+ Usage `ToolContextFunc[AgentDepsT, ToolParams]`.
  """
  ToolFuncPlain = Callable[ToolParams, Any]
  """A tool function that does not take `RunContext` as the first argument.
@@ -98,7 +98,7 @@ ToolFuncEither = Union[ToolFuncContext[AgentDepsT, ToolParams], ToolFuncPlain[To
  This is just a union of [`ToolFuncContext`][pydantic_ai.tools.ToolFuncContext] and
  [`ToolFuncPlain`][pydantic_ai.tools.ToolFuncPlain].

- Usage `ToolFuncEither[AgentDeps, ToolParams]`.
+ Usage `ToolFuncEither[AgentDepsT, ToolParams]`.
  """
  ToolPrepareFunc: TypeAlias = 'Callable[[RunContext[AgentDepsT], ToolDefinition], Awaitable[ToolDefinition | None]]'
  """Definition of a function that can prepare a tool definition at call time.
@@ -125,7 +125,7 @@ def hitchhiker(ctx: RunContext[int], answer: str) -> str:
  hitchhiker = Tool(hitchhiker, prepare=only_if_42)
  ```

- Usage `ToolPrepareFunc[AgentDeps]`.
+ Usage `ToolPrepareFunc[AgentDepsT]`.
  """

  DocstringFormat = Literal['google', 'numpy', 'sphinx', 'auto']
@@ -158,6 +158,9 @@ class Tool(Generic[AgentDepsT]):
  _var_positional_field: str | None = field(init=False)
  _validator: SchemaValidator = field(init=False, repr=False)
  _parameters_json_schema: ObjectJsonSchema = field(init=False)
+
+ # TODO: Move this state off the Tool class, which is otherwise stateless.
+ # This should be tracked inside a specific agent run, not the tool.
  current_retry: int = field(default=0, init=False)

  def __init__(
@@ -261,13 +264,13 @@ class Tool(Generic[AgentDepsT]):

  async def run(
  self, message: _messages.ToolCallPart, run_context: RunContext[AgentDepsT]
- ) -> _messages.ModelRequestPart:
+ ) -> _messages.ToolReturnPart | _messages.RetryPromptPart:
  """Run the tool function asynchronously."""
  try:
- if isinstance(message.args, _messages.ArgsJson):
- args_dict = self._validator.validate_json(message.args.args_json)
+ if isinstance(message.args, str):
+ args_dict = self._validator.validate_json(message.args)
  else:
- args_dict = self._validator.validate_python(message.args.args_dict)
+ args_dict = self._validator.validate_python(message.args)
  except ValidationError as e:
  return self._on_error(e, message)
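
The change above reflects `ToolCallPart.args` now being either a raw JSON string or an already-parsed dict rather than an `ArgsJson` wrapper. A minimal sketch of the equivalent validation logic using plain pydantic, outside the library:

```python
from __future__ import annotations

from pydantic import TypeAdapter

adapter = TypeAdapter(dict[str, int])


def validate_args(args: str | dict[str, int]) -> dict[str, int]:
    # Mirrors Tool.run: JSON strings go through validate_json,
    # already-parsed mappings through validate_python.
    if isinstance(args, str):
        return adapter.validate_json(args)
    return adapter.validate_python(args)


print(validate_args('{"x": 1}'))  # {'x': 1}
print(validate_args({'x': 2}))  # {'x': 2}
```
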
pydantic_ai_slim-0.0.22.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: pydantic-ai-slim
- Version: 0.0.20
+ Version: 0.0.22
  Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
  Author-email: Samuel Colvin <samuel@pydantic.dev>
  License-Expression: MIT
@@ -28,13 +28,12 @@ Requires-Dist: eval-type-backport>=0.2.0
  Requires-Dist: griffe>=1.3.2
  Requires-Dist: httpx>=0.27
  Requires-Dist: logfire-api>=1.2.0
+ Requires-Dist: pydantic-graph==0.0.22
  Requires-Dist: pydantic>=2.10
  Provides-Extra: anthropic
  Requires-Dist: anthropic>=0.40.0; extra == 'anthropic'
  Provides-Extra: cohere
  Requires-Dist: cohere>=5.13.11; extra == 'cohere'
- Provides-Extra: graph
- Requires-Dist: pydantic-graph==0.0.20; extra == 'graph'
  Provides-Extra: groq
  Requires-Dist: groq>=0.12.0; extra == 'groq'
  Provides-Extra: logfire