pydantic-ai-slim 0.0.20__tar.gz → 0.0.21__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic; see the registry's advisory page for this release for more details.

Files changed (31):
  1. {pydantic_ai_slim-0.0.20 → pydantic_ai_slim-0.0.21}/PKG-INFO +2 -2
  2. {pydantic_ai_slim-0.0.20 → pydantic_ai_slim-0.0.21}/pydantic_ai/_parts_manager.py +1 -1
  3. {pydantic_ai_slim-0.0.20 → pydantic_ai_slim-0.0.21}/pydantic_ai/_result.py +3 -7
  4. {pydantic_ai_slim-0.0.20 → pydantic_ai_slim-0.0.21}/pydantic_ai/_utils.py +1 -56
  5. {pydantic_ai_slim-0.0.20 → pydantic_ai_slim-0.0.21}/pydantic_ai/agent.py +34 -30
  6. {pydantic_ai_slim-0.0.20 → pydantic_ai_slim-0.0.21}/pydantic_ai/messages.py +21 -46
  7. {pydantic_ai_slim-0.0.20 → pydantic_ai_slim-0.0.21}/pydantic_ai/models/__init__.py +100 -57
  8. {pydantic_ai_slim-0.0.20 → pydantic_ai_slim-0.0.21}/pydantic_ai/models/anthropic.py +17 -10
  9. {pydantic_ai_slim-0.0.20 → pydantic_ai_slim-0.0.21}/pydantic_ai/models/cohere.py +37 -25
  10. {pydantic_ai_slim-0.0.20 → pydantic_ai_slim-0.0.21}/pydantic_ai/models/gemini.py +20 -6
  11. {pydantic_ai_slim-0.0.20 → pydantic_ai_slim-0.0.21}/pydantic_ai/models/groq.py +19 -17
  12. {pydantic_ai_slim-0.0.20 → pydantic_ai_slim-0.0.21}/pydantic_ai/models/mistral.py +22 -23
  13. {pydantic_ai_slim-0.0.20 → pydantic_ai_slim-0.0.21}/pydantic_ai/models/openai.py +19 -11
  14. {pydantic_ai_slim-0.0.20 → pydantic_ai_slim-0.0.21}/pydantic_ai/models/test.py +37 -22
  15. {pydantic_ai_slim-0.0.20 → pydantic_ai_slim-0.0.21}/pydantic_ai/result.py +1 -1
  16. {pydantic_ai_slim-0.0.20 → pydantic_ai_slim-0.0.21}/pydantic_ai/settings.py +41 -1
  17. {pydantic_ai_slim-0.0.20 → pydantic_ai_slim-0.0.21}/pydantic_ai/tools.py +11 -8
  18. {pydantic_ai_slim-0.0.20 → pydantic_ai_slim-0.0.21}/pyproject.toml +2 -2
  19. pydantic_ai_slim-0.0.20/pydantic_ai/models/ollama.py +0 -123
  20. {pydantic_ai_slim-0.0.20 → pydantic_ai_slim-0.0.21}/.gitignore +0 -0
  21. {pydantic_ai_slim-0.0.20 → pydantic_ai_slim-0.0.21}/README.md +0 -0
  22. {pydantic_ai_slim-0.0.20 → pydantic_ai_slim-0.0.21}/pydantic_ai/__init__.py +0 -0
  23. {pydantic_ai_slim-0.0.20 → pydantic_ai_slim-0.0.21}/pydantic_ai/_griffe.py +0 -0
  24. {pydantic_ai_slim-0.0.20 → pydantic_ai_slim-0.0.21}/pydantic_ai/_pydantic.py +0 -0
  25. {pydantic_ai_slim-0.0.20 → pydantic_ai_slim-0.0.21}/pydantic_ai/_system_prompt.py +0 -0
  26. {pydantic_ai_slim-0.0.20 → pydantic_ai_slim-0.0.21}/pydantic_ai/exceptions.py +0 -0
  27. {pydantic_ai_slim-0.0.20 → pydantic_ai_slim-0.0.21}/pydantic_ai/format_as_xml.py +0 -0
  28. {pydantic_ai_slim-0.0.20 → pydantic_ai_slim-0.0.21}/pydantic_ai/models/function.py +0 -0
  29. {pydantic_ai_slim-0.0.20 → pydantic_ai_slim-0.0.21}/pydantic_ai/models/vertexai.py +0 -0
  30. {pydantic_ai_slim-0.0.20 → pydantic_ai_slim-0.0.21}/pydantic_ai/py.typed +0 -0
  31. {pydantic_ai_slim-0.0.20 → pydantic_ai_slim-0.0.21}/pydantic_ai/usage.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pydantic-ai-slim
3
- Version: 0.0.20
3
+ Version: 0.0.21
4
4
  Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
5
5
  Author-email: Samuel Colvin <samuel@pydantic.dev>
6
6
  License-Expression: MIT
@@ -34,7 +34,7 @@ Requires-Dist: anthropic>=0.40.0; extra == 'anthropic'
34
34
  Provides-Extra: cohere
35
35
  Requires-Dist: cohere>=5.13.11; extra == 'cohere'
36
36
  Provides-Extra: graph
37
- Requires-Dist: pydantic-graph==0.0.20; extra == 'graph'
37
+ Requires-Dist: pydantic-graph==0.0.21; extra == 'graph'
38
38
  Provides-Extra: groq
39
39
  Requires-Dist: groq>=0.12.0; extra == 'groq'
40
40
  Provides-Extra: logfire
@@ -221,7 +221,7 @@ class ModelResponsePartsManager:
221
221
  ModelResponseStreamEvent: A `PartStartEvent` indicating that a new tool call part
222
222
  has been added to the manager, or replaced an existing part.
223
223
  """
224
- new_part = ToolCallPart.from_raw_args(tool_name=tool_name, args=args, tool_call_id=tool_call_id)
224
+ new_part = ToolCallPart(tool_name=tool_name, args=args, tool_call_id=tool_call_id)
225
225
  if vendor_part_id is None:
226
226
  # vendor_part_id is None, so we unconditionally append a new ToolCallPart to the end of the list
227
227
  new_part_index = len(self._parts)
@@ -201,14 +201,10 @@ class ResultTool(Generic[ResultDataT]):
201
201
  """
202
202
  try:
203
203
  pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off'
204
- if isinstance(tool_call.args, _messages.ArgsJson):
205
- result = self.type_adapter.validate_json(
206
- tool_call.args.args_json or '', experimental_allow_partial=pyd_allow_partial
207
- )
204
+ if isinstance(tool_call.args, str):
205
+ result = self.type_adapter.validate_json(tool_call.args, experimental_allow_partial=pyd_allow_partial)
208
206
  else:
209
- result = self.type_adapter.validate_python(
210
- tool_call.args.args_dict, experimental_allow_partial=pyd_allow_partial
211
- )
207
+ result = self.type_adapter.validate_python(tool_call.args, experimental_allow_partial=pyd_allow_partial)
212
208
  except ValidationError as e:
213
209
  if wrap_validation_errors:
214
210
  m = _messages.RetryPromptPart(
@@ -8,7 +8,7 @@ from dataclasses import dataclass, is_dataclass
8
8
  from datetime import datetime, timezone
9
9
  from functools import partial
10
10
  from types import GenericAlias
11
- from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, cast, overload
11
+ from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union
12
12
 
13
13
  from pydantic import BaseModel
14
14
  from pydantic.json_schema import JsonSchemaValue
@@ -79,61 +79,6 @@ def is_set(t_or_unset: T | Unset) -> TypeGuard[T]:
79
79
  return t_or_unset is not UNSET
80
80
 
81
81
 
82
- Left = TypeVar('Left')
83
- Right = TypeVar('Right')
84
-
85
-
86
- class Either(Generic[Left, Right]):
87
- """Two member Union that records which member was set, this is analogous to Rust enums with two variants.
88
-
89
- Usage:
90
-
91
- ```python
92
- if left_thing := either.left:
93
- use_left(left_thing.value)
94
- else:
95
- use_right(either.right)
96
- ```
97
- """
98
-
99
- __slots__ = '_left', '_right'
100
-
101
- @overload
102
- def __init__(self, *, left: Left) -> None: ...
103
-
104
- @overload
105
- def __init__(self, *, right: Right) -> None: ...
106
-
107
- def __init__(self, left: Left | Unset = UNSET, right: Right | Unset = UNSET) -> None:
108
- if left is not UNSET:
109
- assert right is UNSET, '`Either` must receive exactly one argument - `left` or `right`'
110
- self._left: Option[Left] = Some(cast(Left, left))
111
- else:
112
- assert right is not UNSET, '`Either` must receive exactly one argument - `left` or `right`'
113
- self._left = None
114
- self._right = cast(Right, right)
115
-
116
- @property
117
- def left(self) -> Option[Left]:
118
- return self._left
119
-
120
- @property
121
- def right(self) -> Right:
122
- return self._right
123
-
124
- def is_left(self) -> bool:
125
- return self._left is not None
126
-
127
- def whichever(self) -> Left | Right:
128
- return self._left.value if self._left is not None else self.right
129
-
130
- def __repr__(self):
131
- if left := self._left:
132
- return f'Either(left={left.value!r})'
133
- else:
134
- return f'Either(right={self.right!r})'
135
-
136
-
137
82
  @asynccontextmanager
138
83
  async def group_by_temporal(
139
84
  aiterable: AsyncIterable[T], soft_max_interval: float | None
@@ -60,7 +60,7 @@ EndStrategy = Literal['early', 'exhaustive']
60
60
  - `'early'`: Stop processing other tool calls once a final result is found
61
61
  - `'exhaustive'`: Process all tool calls even after finding a final result
62
62
  """
63
- RunResultData = TypeVar('RunResultData')
63
+ RunResultDataT = TypeVar('RunResultDataT')
64
64
  """Type variable for the result data of a run where `result_type` was customized on the run call."""
65
65
 
66
66
 
@@ -214,7 +214,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
214
214
  self,
215
215
  user_prompt: str,
216
216
  *,
217
- result_type: type[RunResultData],
217
+ result_type: type[RunResultDataT],
218
218
  message_history: list[_messages.ModelMessage] | None = None,
219
219
  model: models.Model | models.KnownModelName | None = None,
220
220
  deps: AgentDepsT = None,
@@ -222,7 +222,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
222
222
  usage_limits: _usage.UsageLimits | None = None,
223
223
  usage: _usage.Usage | None = None,
224
224
  infer_name: bool = True,
225
- ) -> result.RunResult[RunResultData]: ...
225
+ ) -> result.RunResult[RunResultDataT]: ...
226
226
 
227
227
  async def run(
228
228
  self,
@@ -234,7 +234,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
234
234
  model_settings: ModelSettings | None = None,
235
235
  usage_limits: _usage.UsageLimits | None = None,
236
236
  usage: _usage.Usage | None = None,
237
- result_type: type[RunResultData] | None = None,
237
+ result_type: type[RunResultDataT] | None = None,
238
238
  infer_name: bool = True,
239
239
  ) -> result.RunResult[Any]:
240
240
  """Run the agent with a user prompt in async mode.
@@ -352,7 +352,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
352
352
  self,
353
353
  user_prompt: str,
354
354
  *,
355
- result_type: type[RunResultData] | None,
355
+ result_type: type[RunResultDataT] | None,
356
356
  message_history: list[_messages.ModelMessage] | None = None,
357
357
  model: models.Model | models.KnownModelName | None = None,
358
358
  deps: AgentDepsT = None,
@@ -360,13 +360,13 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
360
360
  usage_limits: _usage.UsageLimits | None = None,
361
361
  usage: _usage.Usage | None = None,
362
362
  infer_name: bool = True,
363
- ) -> result.RunResult[RunResultData]: ...
363
+ ) -> result.RunResult[RunResultDataT]: ...
364
364
 
365
365
  def run_sync(
366
366
  self,
367
367
  user_prompt: str,
368
368
  *,
369
- result_type: type[RunResultData] | None = None,
369
+ result_type: type[RunResultDataT] | None = None,
370
370
  message_history: list[_messages.ModelMessage] | None = None,
371
371
  model: models.Model | models.KnownModelName | None = None,
372
372
  deps: AgentDepsT = None,
@@ -442,7 +442,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
442
442
  self,
443
443
  user_prompt: str,
444
444
  *,
445
- result_type: type[RunResultData],
445
+ result_type: type[RunResultDataT],
446
446
  message_history: list[_messages.ModelMessage] | None = None,
447
447
  model: models.Model | models.KnownModelName | None = None,
448
448
  deps: AgentDepsT = None,
@@ -450,14 +450,14 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
450
450
  usage_limits: _usage.UsageLimits | None = None,
451
451
  usage: _usage.Usage | None = None,
452
452
  infer_name: bool = True,
453
- ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunResultData]]: ...
453
+ ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunResultDataT]]: ...
454
454
 
455
455
  @asynccontextmanager
456
456
  async def run_stream(
457
457
  self,
458
458
  user_prompt: str,
459
459
  *,
460
- result_type: type[RunResultData] | None = None,
460
+ result_type: type[RunResultDataT] | None = None,
461
461
  message_history: list[_messages.ModelMessage] | None = None,
462
462
  model: models.Model | models.KnownModelName | None = None,
463
463
  deps: AgentDepsT = None,
@@ -572,7 +572,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
572
572
  # there are result validators that might convert the result data from an overridden
573
573
  # `result_type` to a type that is not valid as such.
574
574
  result_validators = cast(
575
- list[_result.ResultValidator[AgentDepsT, RunResultData]], self._result_validators
575
+ list[_result.ResultValidator[AgentDepsT, RunResultDataT]], self._result_validators
576
576
  )
577
577
 
578
578
  yield result.StreamedRunResult(
@@ -999,7 +999,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
999
999
  return model_
1000
1000
 
1001
1001
  async def _prepare_model(
1002
- self, run_context: RunContext[AgentDepsT], result_schema: _result.ResultSchema[RunResultData] | None
1002
+ self, run_context: RunContext[AgentDepsT], result_schema: _result.ResultSchema[RunResultDataT] | None
1003
1003
  ) -> models.AgentModel:
1004
1004
  """Build tools and create an agent model."""
1005
1005
  function_tools: list[ToolDefinition] = []
@@ -1035,8 +1035,8 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
1035
1035
  )
1036
1036
 
1037
1037
  def _prepare_result_schema(
1038
- self, result_type: type[RunResultData] | None
1039
- ) -> _result.ResultSchema[RunResultData] | None:
1038
+ self, result_type: type[RunResultDataT] | None
1039
+ ) -> _result.ResultSchema[RunResultDataT] | None:
1040
1040
  if result_type is not None:
1041
1041
  if self._result_validators:
1042
1042
  raise exceptions.UserError('Cannot set a custom run `result_type` when the agent has result validators')
@@ -1053,7 +1053,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
1053
1053
  run_context: RunContext[AgentDepsT],
1054
1054
  ) -> list[_messages.ModelMessage]:
1055
1055
  try:
1056
- ctx_messages = _messages_ctx_var.get()
1056
+ ctx_messages = get_captured_run_messages()
1057
1057
  except LookupError:
1058
1058
  messages: list[_messages.ModelMessage] = []
1059
1059
  else:
@@ -1080,8 +1080,8 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
1080
1080
  self,
1081
1081
  model_response: _messages.ModelResponse,
1082
1082
  run_context: RunContext[AgentDepsT],
1083
- result_schema: _result.ResultSchema[RunResultData] | None,
1084
- ) -> tuple[_MarkFinalResult[RunResultData] | None, list[_messages.ModelRequestPart]]:
1083
+ result_schema: _result.ResultSchema[RunResultDataT] | None,
1084
+ ) -> tuple[_MarkFinalResult[RunResultDataT] | None, list[_messages.ModelRequestPart]]:
1085
1085
  """Process a non-streamed response from the model.
1086
1086
 
1087
1087
  Returns:
@@ -1110,11 +1110,11 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
1110
1110
  raise exceptions.UnexpectedModelBehavior('Received empty model response')
1111
1111
 
1112
1112
  async def _handle_text_response(
1113
- self, text: str, run_context: RunContext[AgentDepsT], result_schema: _result.ResultSchema[RunResultData] | None
1114
- ) -> tuple[_MarkFinalResult[RunResultData] | None, list[_messages.ModelRequestPart]]:
1113
+ self, text: str, run_context: RunContext[AgentDepsT], result_schema: _result.ResultSchema[RunResultDataT] | None
1114
+ ) -> tuple[_MarkFinalResult[RunResultDataT] | None, list[_messages.ModelRequestPart]]:
1115
1115
  """Handle a plain text response from the model for non-streaming responses."""
1116
1116
  if self._allow_text_result(result_schema):
1117
- result_data_input = cast(RunResultData, text)
1117
+ result_data_input = cast(RunResultDataT, text)
1118
1118
  try:
1119
1119
  result_data = await self._validate_result(result_data_input, run_context, None)
1120
1120
  except _result.ToolRetryError as e:
@@ -1133,13 +1133,13 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
1133
1133
  self,
1134
1134
  tool_calls: list[_messages.ToolCallPart],
1135
1135
  run_context: RunContext[AgentDepsT],
1136
- result_schema: _result.ResultSchema[RunResultData] | None,
1137
- ) -> tuple[_MarkFinalResult[RunResultData] | None, list[_messages.ModelRequestPart]]:
1136
+ result_schema: _result.ResultSchema[RunResultDataT] | None,
1137
+ ) -> tuple[_MarkFinalResult[RunResultDataT] | None, list[_messages.ModelRequestPart]]:
1138
1138
  """Handle a structured response containing tool calls from the model for non-streaming responses."""
1139
1139
  assert tool_calls, 'Expected at least one tool call'
1140
1140
 
1141
1141
  # first look for the result tool call
1142
- final_result: _MarkFinalResult[RunResultData] | None = None
1142
+ final_result: _MarkFinalResult[RunResultDataT] | None = None
1143
1143
 
1144
1144
  parts: list[_messages.ModelRequestPart] = []
1145
1145
  if result_schema is not None:
@@ -1168,7 +1168,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
1168
1168
  tool_calls: list[_messages.ToolCallPart],
1169
1169
  result_tool_name: str | None,
1170
1170
  run_context: RunContext[AgentDepsT],
1171
- result_schema: _result.ResultSchema[RunResultData] | None,
1171
+ result_schema: _result.ResultSchema[RunResultDataT] | None,
1172
1172
  ) -> list[_messages.ModelRequestPart]:
1173
1173
  """Process function (non-result) tool calls in parallel.
1174
1174
 
@@ -1227,7 +1227,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
1227
1227
  self,
1228
1228
  streamed_response: models.StreamedResponse,
1229
1229
  run_context: RunContext[AgentDepsT],
1230
- result_schema: _result.ResultSchema[RunResultData] | None,
1230
+ result_schema: _result.ResultSchema[RunResultDataT] | None,
1231
1231
  ) -> _MarkFinalResult[models.StreamedResponse] | tuple[_messages.ModelResponse, list[_messages.ModelRequestPart]]:
1232
1232
  """Process a streamed response from the model.
1233
1233
 
@@ -1282,15 +1282,15 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
1282
1282
 
1283
1283
  async def _validate_result(
1284
1284
  self,
1285
- result_data: RunResultData,
1285
+ result_data: RunResultDataT,
1286
1286
  run_context: RunContext[AgentDepsT],
1287
1287
  tool_call: _messages.ToolCallPart | None,
1288
- ) -> RunResultData:
1288
+ ) -> RunResultDataT:
1289
1289
  if self._result_validators:
1290
1290
  agent_result_data = cast(ResultDataT, result_data)
1291
1291
  for validator in self._result_validators:
1292
1292
  agent_result_data = await validator.validate(agent_result_data, tool_call, run_context)
1293
- return cast(RunResultData, agent_result_data)
1293
+ return cast(RunResultDataT, agent_result_data)
1294
1294
  else:
1295
1295
  return result_data
1296
1296
 
@@ -1315,7 +1315,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
1315
1315
  def _unknown_tool(
1316
1316
  self,
1317
1317
  tool_name: str,
1318
- result_schema: _result.ResultSchema[RunResultData] | None,
1318
+ result_schema: _result.ResultSchema[RunResultDataT] | None,
1319
1319
  ) -> _messages.RetryPromptPart:
1320
1320
  names = list(self._function_tools.keys())
1321
1321
  if result_schema:
@@ -1358,7 +1358,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
1358
1358
  return
1359
1359
 
1360
1360
  @staticmethod
1361
- def _allow_text_result(result_schema: _result.ResultSchema[RunResultData] | None) -> bool:
1361
+ def _allow_text_result(result_schema: _result.ResultSchema[RunResultDataT] | None) -> bool:
1362
1362
  return result_schema is None or result_schema.allow_text_result
1363
1363
 
1364
1364
  @property
@@ -1413,6 +1413,10 @@ def capture_run_messages() -> Iterator[list[_messages.ModelMessage]]:
1413
1413
  _messages_ctx_var.reset(token)
1414
1414
 
1415
1415
 
1416
+ def get_captured_run_messages() -> _RunMessages:
1417
+ return _messages_ctx_var.get()
1418
+
1419
+
1416
1420
  @dataclasses.dataclass
1417
1421
  class _MarkFinalResult(Generic[ResultDataT]):
1418
1422
  """Marker class to indicate that the result is the final result.
@@ -6,7 +6,6 @@ from typing import Annotated, Any, Literal, Union, cast, overload
6
6
 
7
7
  import pydantic
8
8
  import pydantic_core
9
- from typing_extensions import Self, assert_never
10
9
 
11
10
  from ._utils import now_utc as _now_utc
12
11
  from .exceptions import UnexpectedModelBehavior
@@ -168,22 +167,6 @@ class TextPart:
168
167
  return bool(self.content)
169
168
 
170
169
 
171
- @dataclass
172
- class ArgsJson:
173
- """Tool arguments as a JSON string."""
174
-
175
- args_json: str
176
- """A JSON string of arguments."""
177
-
178
-
179
- @dataclass
180
- class ArgsDict:
181
- """Tool arguments as a Python dictionary."""
182
-
183
- args_dict: dict[str, Any]
184
- """A python dictionary of arguments."""
185
-
186
-
187
170
  @dataclass
188
171
  class ToolCallPart:
189
172
  """A tool call from a model."""
@@ -191,10 +174,10 @@ class ToolCallPart:
191
174
  tool_name: str
192
175
  """The name of the tool to call."""
193
176
 
194
- args: ArgsJson | ArgsDict
177
+ args: str | dict[str, Any]
195
178
  """The arguments to pass to the tool.
196
179
 
197
- Either as JSON or a Python dictionary depending on how data was returned.
180
+ This is stored either as a JSON string or a Python dictionary depending on how data was received.
198
181
  """
199
182
 
200
183
  tool_call_id: str | None = None
@@ -203,24 +186,14 @@ class ToolCallPart:
203
186
  part_kind: Literal['tool-call'] = 'tool-call'
204
187
  """Part type identifier, this is available on all parts as a discriminator."""
205
188
 
206
- @classmethod
207
- def from_raw_args(cls, tool_name: str, args: str | dict[str, Any], tool_call_id: str | None = None) -> Self:
208
- """Create a `ToolCallPart` from raw arguments, converting them to `ArgsJson` or `ArgsDict`."""
209
- if isinstance(args, str):
210
- return cls(tool_name, ArgsJson(args), tool_call_id)
211
- elif isinstance(args, dict):
212
- return cls(tool_name, ArgsDict(args), tool_call_id)
213
- else:
214
- assert_never(args)
215
-
216
189
  def args_as_dict(self) -> dict[str, Any]:
217
190
  """Return the arguments as a Python dictionary.
218
191
 
219
192
  This is just for convenience with models that require dicts as input.
220
193
  """
221
- if isinstance(self.args, ArgsDict):
222
- return self.args.args_dict
223
- args = pydantic_core.from_json(self.args.args_json)
194
+ if isinstance(self.args, dict):
195
+ return self.args
196
+ args = pydantic_core.from_json(self.args)
224
197
  assert isinstance(args, dict), 'args should be a dict'
225
198
  return cast(dict[str, Any], args)
226
199
 
@@ -229,16 +202,18 @@ class ToolCallPart:
229
202
 
230
203
  This is just for convenience with models that require JSON strings as input.
231
204
  """
232
- if isinstance(self.args, ArgsJson):
233
- return self.args.args_json
234
- return pydantic_core.to_json(self.args.args_dict).decode()
205
+ if isinstance(self.args, str):
206
+ return self.args
207
+ return pydantic_core.to_json(self.args).decode()
235
208
 
236
209
  def has_content(self) -> bool:
237
210
  """Return `True` if the arguments contain any data."""
238
- if isinstance(self.args, ArgsDict):
239
- return any(self.args.args_dict.values())
211
+ if isinstance(self.args, dict):
212
+ # TODO: This should probably return True if you have the value False, or 0, etc.
213
+ # It makes sense to me to ignore empty strings, but not sure about empty lists or dicts
214
+ return any(self.args.values())
240
215
  else:
241
- return bool(self.args.args_json)
216
+ return bool(self.args)
242
217
 
243
218
 
244
219
  ModelResponsePart = Annotated[Union[TextPart, ToolCallPart], pydantic.Discriminator('part_kind')]
@@ -331,7 +306,7 @@ class ToolCallPartDelta:
331
306
  if self.tool_name_delta is None or self.args_delta is None:
332
307
  return None
333
308
 
334
- return ToolCallPart.from_raw_args(
309
+ return ToolCallPart(
335
310
  self.tool_name_delta,
336
311
  self.args_delta,
337
312
  self.tool_call_id,
@@ -396,7 +371,7 @@ class ToolCallPartDelta:
396
371
 
397
372
  # If we now have enough data to create a full ToolCallPart, do so
398
373
  if delta.tool_name_delta is not None and delta.args_delta is not None:
399
- return ToolCallPart.from_raw_args(
374
+ return ToolCallPart(
400
375
  delta.tool_name_delta,
401
376
  delta.args_delta,
402
377
  delta.tool_call_id,
@@ -412,15 +387,15 @@ class ToolCallPartDelta:
412
387
  part = replace(part, tool_name=tool_name)
413
388
 
414
389
  if isinstance(self.args_delta, str):
415
- if not isinstance(part.args, ArgsJson):
390
+ if not isinstance(part.args, str):
416
391
  raise UnexpectedModelBehavior(f'Cannot apply JSON deltas to non-JSON tool arguments ({part=}, {self=})')
417
- updated_json = part.args.args_json + self.args_delta
418
- part = replace(part, args=ArgsJson(updated_json))
392
+ updated_json = part.args + self.args_delta
393
+ part = replace(part, args=updated_json)
419
394
  elif isinstance(self.args_delta, dict):
420
- if not isinstance(part.args, ArgsDict):
395
+ if not isinstance(part.args, dict):
421
396
  raise UnexpectedModelBehavior(f'Cannot apply dict deltas to non-dict tool arguments ({part=}, {self=})')
422
- updated_dict = {**(part.args.args_dict or {}), **self.args_delta}
423
- part = replace(part, args=ArgsDict(updated_dict))
397
+ updated_dict = {**(part.args or {}), **self.args_delta}
398
+ part = replace(part, args=updated_dict)
424
399
 
425
400
  if self.tool_call_id:
426
401
  # Replace the tool_call_id entirely if given
@@ -12,9 +12,10 @@ from contextlib import asynccontextmanager, contextmanager
12
12
  from dataclasses import dataclass, field
13
13
  from datetime import datetime
14
14
  from functools import cache
15
- from typing import TYPE_CHECKING, Literal
15
+ from typing import TYPE_CHECKING
16
16
 
17
17
  import httpx
18
+ from typing_extensions import Literal
18
19
 
19
20
  from .._parts_manager import ModelResponsePartsManager
20
21
  from ..exceptions import UserError
@@ -27,58 +28,6 @@ if TYPE_CHECKING:
27
28
 
28
29
 
29
30
  KnownModelName = Literal[
30
- 'openai:gpt-4o',
31
- 'openai:gpt-4o-mini',
32
- 'openai:gpt-4-turbo',
33
- 'openai:gpt-4',
34
- 'openai:o1-preview',
35
- 'openai:o1-mini',
36
- 'openai:o1',
37
- 'openai:gpt-3.5-turbo',
38
- 'groq:llama-3.3-70b-versatile',
39
- 'groq:llama-3.1-70b-versatile',
40
- 'groq:llama3-groq-70b-8192-tool-use-preview',
41
- 'groq:llama3-groq-8b-8192-tool-use-preview',
42
- 'groq:llama-3.1-70b-specdec',
43
- 'groq:llama-3.1-8b-instant',
44
- 'groq:llama-3.2-1b-preview',
45
- 'groq:llama-3.2-3b-preview',
46
- 'groq:llama-3.2-11b-vision-preview',
47
- 'groq:llama-3.2-90b-vision-preview',
48
- 'groq:llama3-70b-8192',
49
- 'groq:llama3-8b-8192',
50
- 'groq:mixtral-8x7b-32768',
51
- 'groq:gemma2-9b-it',
52
- 'groq:gemma-7b-it',
53
- 'google-gla:gemini-1.5-flash',
54
- 'google-gla:gemini-1.5-pro',
55
- 'google-gla:gemini-2.0-flash-exp',
56
- 'google-vertex:gemini-1.5-flash',
57
- 'google-vertex:gemini-1.5-pro',
58
- 'google-vertex:gemini-2.0-flash-exp',
59
- 'mistral:mistral-small-latest',
60
- 'mistral:mistral-large-latest',
61
- 'mistral:codestral-latest',
62
- 'mistral:mistral-moderation-latest',
63
- 'ollama:codellama',
64
- 'ollama:deepseek-r1',
65
- 'ollama:gemma',
66
- 'ollama:gemma2',
67
- 'ollama:llama3',
68
- 'ollama:llama3.1',
69
- 'ollama:llama3.2',
70
- 'ollama:llama3.2-vision',
71
- 'ollama:llama3.3',
72
- 'ollama:mistral',
73
- 'ollama:mistral-nemo',
74
- 'ollama:mixtral',
75
- 'ollama:phi3',
76
- 'ollama:phi4',
77
- 'ollama:qwq',
78
- 'ollama:qwen',
79
- 'ollama:qwen2',
80
- 'ollama:qwen2.5',
81
- 'ollama:starcoder2',
82
31
  'anthropic:claude-3-5-haiku-latest',
83
32
  'anthropic:claude-3-5-sonnet-latest',
84
33
  'anthropic:claude-3-opus-latest',
@@ -98,6 +47,104 @@ KnownModelName = Literal[
98
47
  'cohere:command-r-plus-04-2024',
99
48
  'cohere:command-r-plus-08-2024',
100
49
  'cohere:command-r7b-12-2024',
50
+ 'google-gla:gemini-1.0-pro',
51
+ 'google-gla:gemini-1.5-flash',
52
+ 'google-gla:gemini-1.5-flash-8b',
53
+ 'google-gla:gemini-1.5-pro',
54
+ 'google-gla:gemini-2.0-flash-exp',
55
+ 'google-vertex:gemini-1.0-pro',
56
+ 'google-vertex:gemini-1.5-flash',
57
+ 'google-vertex:gemini-1.5-flash-8b',
58
+ 'google-vertex:gemini-1.5-pro',
59
+ 'google-vertex:gemini-2.0-flash-exp',
60
+ 'gpt-3.5-turbo',
61
+ 'gpt-3.5-turbo-0125',
62
+ 'gpt-3.5-turbo-0301',
63
+ 'gpt-3.5-turbo-0613',
64
+ 'gpt-3.5-turbo-1106',
65
+ 'gpt-3.5-turbo-16k',
66
+ 'gpt-3.5-turbo-16k-0613',
67
+ 'gpt-4',
68
+ 'gpt-4-0125-preview',
69
+ 'gpt-4-0314',
70
+ 'gpt-4-0613',
71
+ 'gpt-4-1106-preview',
72
+ 'gpt-4-32k',
73
+ 'gpt-4-32k-0314',
74
+ 'gpt-4-32k-0613',
75
+ 'gpt-4-turbo',
76
+ 'gpt-4-turbo-2024-04-09',
77
+ 'gpt-4-turbo-preview',
78
+ 'gpt-4-vision-preview',
79
+ 'gpt-4o',
80
+ 'gpt-4o-2024-05-13',
81
+ 'gpt-4o-2024-08-06',
82
+ 'gpt-4o-2024-11-20',
83
+ 'gpt-4o-audio-preview',
84
+ 'gpt-4o-audio-preview-2024-10-01',
85
+ 'gpt-4o-audio-preview-2024-12-17',
86
+ 'gpt-4o-mini',
87
+ 'gpt-4o-mini-2024-07-18',
88
+ 'gpt-4o-mini-audio-preview',
89
+ 'gpt-4o-mini-audio-preview-2024-12-17',
90
+ 'groq:gemma2-9b-it',
91
+ 'groq:llama-3.1-8b-instant',
92
+ 'groq:llama-3.2-11b-vision-preview',
93
+ 'groq:llama-3.2-1b-preview',
94
+ 'groq:llama-3.2-3b-preview',
95
+ 'groq:llama-3.2-90b-vision-preview',
96
+ 'groq:llama-3.3-70b-specdec',
97
+ 'groq:llama-3.3-70b-versatile',
98
+ 'groq:llama3-70b-8192',
99
+ 'groq:llama3-8b-8192',
100
+ 'groq:mixtral-8x7b-32768',
101
+ 'mistral:codestral-latest',
102
+ 'mistral:mistral-large-latest',
103
+ 'mistral:mistral-moderation-latest',
104
+ 'mistral:mistral-small-latest',
105
+ 'o1',
106
+ 'o1-2024-12-17',
107
+ 'o1-mini',
108
+ 'o1-mini-2024-09-12',
109
+ 'o1-preview',
110
+ 'o1-preview-2024-09-12',
111
+ 'openai:chatgpt-4o-latest',
112
+ 'openai:gpt-3.5-turbo',
113
+ 'openai:gpt-3.5-turbo-0125',
114
+ 'openai:gpt-3.5-turbo-0301',
115
+ 'openai:gpt-3.5-turbo-0613',
116
+ 'openai:gpt-3.5-turbo-1106',
117
+ 'openai:gpt-3.5-turbo-16k',
118
+ 'openai:gpt-3.5-turbo-16k-0613',
119
+ 'openai:gpt-4',
120
+ 'openai:gpt-4-0125-preview',
121
+ 'openai:gpt-4-0314',
122
+ 'openai:gpt-4-0613',
123
+ 'openai:gpt-4-1106-preview',
124
+ 'openai:gpt-4-32k',
125
+ 'openai:gpt-4-32k-0314',
126
+ 'openai:gpt-4-32k-0613',
127
+ 'openai:gpt-4-turbo',
128
+ 'openai:gpt-4-turbo-2024-04-09',
129
+ 'openai:gpt-4-turbo-preview',
130
+ 'openai:gpt-4-vision-preview',
131
+ 'openai:gpt-4o',
132
+ 'openai:gpt-4o-2024-05-13',
133
+ 'openai:gpt-4o-2024-08-06',
134
+ 'openai:gpt-4o-2024-11-20',
135
+ 'openai:gpt-4o-audio-preview',
136
+ 'openai:gpt-4o-audio-preview-2024-10-01',
137
+ 'openai:gpt-4o-audio-preview-2024-12-17',
138
+ 'openai:gpt-4o-mini',
139
+ 'openai:gpt-4o-mini-2024-07-18',
140
+ 'openai:gpt-4o-mini-audio-preview',
141
+ 'openai:gpt-4o-mini-audio-preview-2024-12-17',
142
+ 'openai:o1',
143
+ 'openai:o1-2024-12-17',
144
+ 'openai:o1-mini',
145
+ 'openai:o1-mini-2024-09-12',
146
+ 'openai:o1-preview',
147
+ 'openai:o1-preview-2024-09-12',
101
148
  'test',
102
149
  ]
103
150
  """Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent].
@@ -291,10 +338,6 @@ def infer_model(model: Model | KnownModelName) -> Model:
291
338
  from .mistral import MistralModel
292
339
 
293
340
  return MistralModel(model[8:])
294
- elif model.startswith('ollama:'):
295
- from .ollama import OllamaModel
296
-
297
- return OllamaModel(model[7:])
298
341
  elif model.startswith('anthropic'):
299
342
  from .anthropic import AnthropicModel
300
343