pydantic-ai-slim 0.0.17__tar.gz → 0.0.18__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pydantic_ai_slim-0.0.17 → pydantic_ai_slim-0.0.18}/PKG-INFO +1 -1
- {pydantic_ai_slim-0.0.17 → pydantic_ai_slim-0.0.18}/pydantic_ai/_griffe.py +13 -1
- {pydantic_ai_slim-0.0.17 → pydantic_ai_slim-0.0.18}/pydantic_ai/_system_prompt.py +1 -0
- {pydantic_ai_slim-0.0.17 → pydantic_ai_slim-0.0.18}/pydantic_ai/agent.py +256 -56
- {pydantic_ai_slim-0.0.17 → pydantic_ai_slim-0.0.18}/pydantic_ai/messages.py +6 -0
- {pydantic_ai_slim-0.0.17 → pydantic_ai_slim-0.0.18}/pydantic_ai/models/__init__.py +28 -10
- {pydantic_ai_slim-0.0.17 → pydantic_ai_slim-0.0.18}/pydantic_ai/models/anthropic.py +1 -1
- {pydantic_ai_slim-0.0.17 → pydantic_ai_slim-0.0.18}/pydantic_ai/models/gemini.py +10 -3
- {pydantic_ai_slim-0.0.17 → pydantic_ai_slim-0.0.18}/pydantic_ai/models/vertexai.py +1 -1
- {pydantic_ai_slim-0.0.17 → pydantic_ai_slim-0.0.18}/pyproject.toml +2 -1
- {pydantic_ai_slim-0.0.17 → pydantic_ai_slim-0.0.18}/.gitignore +0 -0
- {pydantic_ai_slim-0.0.17 → pydantic_ai_slim-0.0.18}/README.md +0 -0
- {pydantic_ai_slim-0.0.17 → pydantic_ai_slim-0.0.18}/pydantic_ai/__init__.py +0 -0
- {pydantic_ai_slim-0.0.17 → pydantic_ai_slim-0.0.18}/pydantic_ai/_pydantic.py +0 -0
- {pydantic_ai_slim-0.0.17 → pydantic_ai_slim-0.0.18}/pydantic_ai/_result.py +0 -0
- {pydantic_ai_slim-0.0.17 → pydantic_ai_slim-0.0.18}/pydantic_ai/_utils.py +0 -0
- {pydantic_ai_slim-0.0.17 → pydantic_ai_slim-0.0.18}/pydantic_ai/exceptions.py +0 -0
- {pydantic_ai_slim-0.0.17 → pydantic_ai_slim-0.0.18}/pydantic_ai/format_as_xml.py +0 -0
- {pydantic_ai_slim-0.0.17 → pydantic_ai_slim-0.0.18}/pydantic_ai/models/function.py +0 -0
- {pydantic_ai_slim-0.0.17 → pydantic_ai_slim-0.0.18}/pydantic_ai/models/groq.py +0 -0
- {pydantic_ai_slim-0.0.17 → pydantic_ai_slim-0.0.18}/pydantic_ai/models/mistral.py +0 -0
- {pydantic_ai_slim-0.0.17 → pydantic_ai_slim-0.0.18}/pydantic_ai/models/ollama.py +0 -0
- {pydantic_ai_slim-0.0.17 → pydantic_ai_slim-0.0.18}/pydantic_ai/models/openai.py +0 -0
- {pydantic_ai_slim-0.0.17 → pydantic_ai_slim-0.0.18}/pydantic_ai/models/test.py +0 -0
- {pydantic_ai_slim-0.0.17 → pydantic_ai_slim-0.0.18}/pydantic_ai/py.typed +0 -0
- {pydantic_ai_slim-0.0.17 → pydantic_ai_slim-0.0.18}/pydantic_ai/result.py +0 -0
- {pydantic_ai_slim-0.0.17 → pydantic_ai_slim-0.0.18}/pydantic_ai/settings.py +0 -0
- {pydantic_ai_slim-0.0.17 → pydantic_ai_slim-0.0.18}/pydantic_ai/tools.py +0 -0
- {pydantic_ai_slim-0.0.17 → pydantic_ai_slim-0.0.18}/pydantic_ai/usage.py +0 -0
{pydantic_ai_slim-0.0.17 → pydantic_ai_slim-0.0.18}/pydantic_ai/_griffe.py

@@ -1,6 +1,8 @@
 from __future__ import annotations as _annotations

+import logging
 import re
+from contextlib import contextmanager
 from inspect import Signature
 from typing import Any, Callable, Literal, cast

@@ -25,7 +27,8 @@ def doc_descriptions(
     parent = cast(GriffeObject, sig)

     docstring = Docstring(doc, lineno=1, parser=style or _infer_docstring_style(doc), parent=parent)
-    sections = docstring.parse()
+    with _disable_griffe_logging():
+        sections = docstring.parse()

     params = {}
     if parameters := next((p for p in sections if p.kind == DocstringSectionKind.parameters), None):
@@ -125,3 +128,12 @@ _docstring_style_patterns: list[tuple[str, list[str], DocstringStyle]] = [
         'numpy',
     ),
 ]
+
+
+@contextmanager
+def _disable_griffe_logging():
+    # Hacky, but suggested here: https://github.com/mkdocstrings/griffe/issues/293#issuecomment-2167668117
+    old_level = logging.root.getEffectiveLevel()
+    logging.root.setLevel(logging.ERROR)
+    yield
+    logging.root.setLevel(old_level)
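The new `_disable_griffe_logging` helper silences griffe's docstring-parsing log output by temporarily raising the root logger level. A minimal standalone sketch of the same pattern (the names below are illustrative, not part of the package):

```python
import logging
from contextlib import contextmanager


@contextmanager
def quiet_root_logger(level: int = logging.ERROR):
    """Temporarily raise the root logger level, then restore it on exit."""
    old_level = logging.root.getEffectiveLevel()
    logging.root.setLevel(level)
    try:
        yield
    finally:  # unlike the library helper, restore even if the body raises
        logging.root.setLevel(old_level)


logging.basicConfig(level=logging.INFO)
with quiet_root_logger():
    logging.getLogger('griffe').warning('suppressed: effective level is ERROR')
logging.getLogger('griffe').warning('visible again')
```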
{pydantic_ai_slim-0.0.17 → pydantic_ai_slim-0.0.18}/pydantic_ai/_system_prompt.py

@@ -12,6 +12,7 @@ from .tools import AgentDeps, RunContext, SystemPromptFunc
 @dataclass
 class SystemPromptRunner(Generic[AgentDeps]):
     function: SystemPromptFunc[AgentDeps]
+    dynamic: bool = False
     _takes_ctx: bool = field(init=False)
     _is_async: bool = field(init=False)

{pydantic_ai_slim-0.0.17 → pydantic_ai_slim-0.0.18}/pydantic_ai/agent.py

@@ -4,13 +4,13 @@ import asyncio
 import dataclasses
 import inspect
 from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
-from contextlib import asynccontextmanager, contextmanager
+from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager
 from contextvars import ContextVar
 from types import FrameType
 from typing import Any, Callable, Generic, Literal, cast, final, overload

 import logfire_api
-from typing_extensions import assert_never, deprecated
+from typing_extensions import TypeVar, assert_never, deprecated

 from . import (
     _result,
@@ -57,6 +57,8 @@ EndStrategy = Literal['early', 'exhaustive']
 - `'early'`: Stop processing other tool calls once a final result is found
 - `'exhaustive'`: Process all tool calls even after finding a final result
 """
+RunResultData = TypeVar('RunResultData')
+"""Type variable for the result data of a run where `result_type` was customized on the run call."""


 @final
@@ -99,14 +101,17 @@ class Agent(Generic[AgentDeps, ResultData]):
     Note, if `model_settings` is provided by `run`, `run_sync`, or `run_stream`, those settings will
     be merged with this value, with the runtime argument taking priority.
     """
-
+    _result_tool_name: str = dataclasses.field(repr=False)
+    _result_tool_description: str | None = dataclasses.field(repr=False)
     _result_schema: _result.ResultSchema[ResultData] | None = dataclasses.field(repr=False)
     _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = dataclasses.field(repr=False)
-    _allow_text_result: bool = dataclasses.field(repr=False)
     _system_prompts: tuple[str, ...] = dataclasses.field(repr=False)
     _function_tools: dict[str, Tool[AgentDeps]] = dataclasses.field(repr=False)
     _default_retries: int = dataclasses.field(repr=False)
     _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = dataclasses.field(repr=False)
+    _system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDeps]] = dataclasses.field(
+        repr=False
+    )
     _deps_type: type[AgentDeps] = dataclasses.field(repr=False)
     _max_result_retries: int = dataclasses.field(repr=False)
     _override_deps: _utils.Option[AgentDeps] = dataclasses.field(default=None, repr=False)
@@ -166,11 +171,11 @@ class Agent(Generic[AgentDeps, ResultData]):
         self.end_strategy = end_strategy
         self.name = name
         self.model_settings = model_settings
+        self._result_tool_name = result_tool_name
+        self._result_tool_description = result_tool_description
         self._result_schema = _result.ResultSchema[result_type].build(
             result_type, result_tool_name, result_tool_description
         )
-        # if the result tool is None, or its schema allows `str`, we allow plain text results
-        self._allow_text_result = self._result_schema is None or self._result_schema.allow_text_result

         self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt)
         self._function_tools = {}
@@ -182,13 +187,31 @@ class Agent(Generic[AgentDeps, ResultData]):
             self._register_tool(Tool(tool))
         self._deps_type = deps_type
         self._system_prompt_functions = []
+        self._system_prompt_dynamic_functions = {}
         self._max_result_retries = result_retries if result_retries is not None else retries
         self._result_validators = []

+    @overload
+    async def run(
+        self,
+        user_prompt: str,
+        *,
+        result_type: None = None,
+        message_history: list[_messages.ModelMessage] | None = None,
+        model: models.Model | models.KnownModelName | None = None,
+        deps: AgentDeps = None,
+        model_settings: ModelSettings | None = None,
+        usage_limits: _usage.UsageLimits | None = None,
+        usage: _usage.Usage | None = None,
+        infer_name: bool = True,
+    ) -> result.RunResult[ResultData]: ...
+
+    @overload
     async def run(
         self,
         user_prompt: str,
         *,
+        result_type: type[RunResultData],
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | None = None,
         deps: AgentDeps = None,
@@ -196,7 +219,21 @@ class Agent(Generic[AgentDeps, ResultData]):
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
         infer_name: bool = True,
-    ) -> result.RunResult[ResultData]:
+    ) -> result.RunResult[RunResultData]: ...
+
+    async def run(
+        self,
+        user_prompt: str,
+        *,
+        message_history: list[_messages.ModelMessage] | None = None,
+        model: models.Model | models.KnownModelName | None = None,
+        deps: AgentDeps = None,
+        model_settings: ModelSettings | None = None,
+        usage_limits: _usage.UsageLimits | None = None,
+        usage: _usage.Usage | None = None,
+        result_type: type[RunResultData] | None = None,
+        infer_name: bool = True,
+    ) -> result.RunResult[Any]:
         """Run the agent with a user prompt in async mode.

         Example:
@@ -211,6 +248,8 @@ class Agent(Generic[AgentDeps, ResultData]):
         ```

         Args:
+            result_type: Custom result type to use for this run, `result_type` may only be used if the agent has no
+                result validators since result validators would expect an argument that matches the agent's result type.
             user_prompt: User input to start/continue the conversation.
             message_history: History of the conversation so far.
             model: Optional model to use for this run, required if `model` was not set when creating the agent.
@@ -229,6 +268,7 @@ class Agent(Generic[AgentDeps, ResultData]):

         deps = self._get_deps(deps)
         new_message_index = len(message_history) if message_history else 0
+        result_schema = self._prepare_result_schema(result_type)

         with _logfire.span(
             '{agent_name} run {prompt=}',
@@ -252,7 +292,7 @@ class Agent(Generic[AgentDeps, ResultData]):

                 run_context.run_step += 1
                 with _logfire.span('preparing model and tools {run_step=}', run_step=run_context.run_step):
-                    agent_model = await self._prepare_model(run_context)
+                    agent_model = await self._prepare_model(run_context, result_schema)

                 with _logfire.span('model request', run_step=run_context.run_step) as model_req_span:
                     model_response, request_usage = await agent_model.request(messages, model_settings)
@@ -264,7 +304,9 @@ class Agent(Generic[AgentDeps, ResultData]):
                 usage_limits.check_tokens(run_context.usage)

                 with _logfire.span('handle model response', run_step=run_context.run_step) as handle_span:
-                    final_result, tool_responses = await self._handle_model_response(model_response, run_context)
+                    final_result, tool_responses = await self._handle_model_response(
+                        model_response, run_context, result_schema
+                    )

                 if tool_responses:
                     # Add parts to the conversation as a new message
@@ -287,10 +329,40 @@ class Agent(Generic[AgentDeps, ResultData]):
                     tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
                     handle_span.message = f'handle model response -> {tool_responses_str}'

+    @overload
+    def run_sync(
+        self,
+        user_prompt: str,
+        *,
+        message_history: list[_messages.ModelMessage] | None = None,
+        model: models.Model | models.KnownModelName | None = None,
+        deps: AgentDeps = None,
+        model_settings: ModelSettings | None = None,
+        usage_limits: _usage.UsageLimits | None = None,
+        usage: _usage.Usage | None = None,
+        infer_name: bool = True,
+    ) -> result.RunResult[ResultData]: ...
+
+    @overload
+    def run_sync(
+        self,
+        user_prompt: str,
+        *,
+        result_type: type[RunResultData] | None,
+        message_history: list[_messages.ModelMessage] | None = None,
+        model: models.Model | models.KnownModelName | None = None,
+        deps: AgentDeps = None,
+        model_settings: ModelSettings | None = None,
+        usage_limits: _usage.UsageLimits | None = None,
+        usage: _usage.Usage | None = None,
+        infer_name: bool = True,
+    ) -> result.RunResult[RunResultData]: ...
+
     def run_sync(
         self,
         user_prompt: str,
         *,
+        result_type: type[RunResultData] | None = None,
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | None = None,
         deps: AgentDeps = None,
@@ -298,7 +370,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
         infer_name: bool = True,
-    ) -> result.RunResult[ResultData]:
+    ) -> result.RunResult[Any]:
         """Run the agent with a user prompt synchronously.

         This is a convenience method that wraps [`self.run`][pydantic_ai.Agent.run] with `loop.run_until_complete(...)`.
@@ -317,6 +389,8 @@ class Agent(Generic[AgentDeps, ResultData]):
         ```

         Args:
+            result_type: Custom result type to use for this run, `result_type` may only be used if the agent has no
+                result validators since result validators would expect an argument that matches the agent's result type.
             user_prompt: User input to start/continue the conversation.
             message_history: History of the conversation so far.
             model: Optional model to use for this run, required if `model` was not set when creating the agent.
@@ -334,6 +408,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         return asyncio.get_event_loop().run_until_complete(
             self.run(
                 user_prompt,
+                result_type=result_type,
                 message_history=message_history,
                 model=model,
                 deps=deps,
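The new `result_type` parameter lets a single call override the agent's result type (it cannot be combined with result validators, per the docstring above). A hedged usage sketch using the built-in `test` model; the `CityLocation` schema is made up for the example:

```python
from pydantic import BaseModel

from pydantic_ai import Agent


class CityLocation(BaseModel):  # illustrative schema, not part of the package
    city: str
    country: str


agent = Agent('test', result_type=str)  # this agent normally returns plain text

# Override the result type for this run only; result.data should now be a CityLocation.
result = agent.run_sync('Where were the 2012 Olympics held?', result_type=CityLocation)
print(repr(result.data))
```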
@@ -344,11 +419,42 @@ class Agent(Generic[AgentDeps, ResultData]):
             )
         )

+    @overload
+    def run_stream(
+        self,
+        user_prompt: str,
+        *,
+        result_type: None = None,
+        message_history: list[_messages.ModelMessage] | None = None,
+        model: models.Model | models.KnownModelName | None = None,
+        deps: AgentDeps = None,
+        model_settings: ModelSettings | None = None,
+        usage_limits: _usage.UsageLimits | None = None,
+        usage: _usage.Usage | None = None,
+        infer_name: bool = True,
+    ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDeps, ResultData]]: ...
+
+    @overload
+    def run_stream(
+        self,
+        user_prompt: str,
+        *,
+        result_type: type[RunResultData],
+        message_history: list[_messages.ModelMessage] | None = None,
+        model: models.Model | models.KnownModelName | None = None,
+        deps: AgentDeps = None,
+        model_settings: ModelSettings | None = None,
+        usage_limits: _usage.UsageLimits | None = None,
+        usage: _usage.Usage | None = None,
+        infer_name: bool = True,
+    ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDeps, RunResultData]]: ...
+
     @asynccontextmanager
     async def run_stream(
         self,
         user_prompt: str,
         *,
+        result_type: type[RunResultData] | None = None,
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | None = None,
         deps: AgentDeps = None,
@@ -356,7 +462,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
         infer_name: bool = True,
-    ) -> AsyncIterator[result.StreamedRunResult[AgentDeps, ResultData]]:
+    ) -> AsyncIterator[result.StreamedRunResult[AgentDeps, Any]]:
         """Run the agent with a user prompt in async mode, returning a streamed response.

         Example:
@@ -372,6 +478,8 @@ class Agent(Generic[AgentDeps, ResultData]):
         ```

         Args:
+            result_type: Custom result type to use for this run, `result_type` may only be used if the agent has no
+                result validators since result validators would expect an argument that matches the agent's result type.
             user_prompt: User input to start/continue the conversation.
             message_history: History of the conversation so far.
             model: Optional model to use for this run, required if `model` was not set when creating the agent.
@@ -392,6 +500,7 @@ class Agent(Generic[AgentDeps, ResultData]):

         deps = self._get_deps(deps)
         new_message_index = len(message_history) if message_history else 0
+        result_schema = self._prepare_result_schema(result_type)

         with _logfire.span(
             '{agent_name} run stream {prompt=}',
@@ -415,7 +524,7 @@ class Agent(Generic[AgentDeps, ResultData]):
                 usage_limits.check_before_request(run_context.usage)

                 with _logfire.span('preparing model and tools {run_step=}', run_step=run_context.run_step):
-                    agent_model = await self._prepare_model(run_context)
+                    agent_model = await self._prepare_model(run_context, result_schema)

                 with _logfire.span('model request {run_step=}', run_step=run_context.run_step) as model_req_span:
                     async with agent_model.request_stream(messages, model_settings) as model_response:
@@ -426,7 +535,9 @@ class Agent(Generic[AgentDeps, ResultData]):
                         model_req_span.__exit__(None, None, None)

                         with _logfire.span('handle model response') as handle_span:
-                            maybe_final_result = await self._handle_streamed_model_response(model_response, run_context)
+                            maybe_final_result = await self._handle_streamed_model_response(
+                                model_response, run_context, result_schema
+                            )

                             # Check if we got a final result
                             if isinstance(maybe_final_result, _MarkFinalResult):
@@ -446,7 +557,7 @@ class Agent(Generic[AgentDeps, ResultData]):
                                     part for part in last_message.parts if isinstance(part, _messages.ToolCallPart)
                                 ]
                                 parts = await self._process_function_tools(
-                                    tool_calls, result_tool_name, run_context
+                                    tool_calls, result_tool_name, run_context, result_schema
                                 )
                                 if parts:
                                     messages.append(_messages.ModelRequest(parts))
@@ -457,7 +568,7 @@ class Agent(Generic[AgentDeps, ResultData]):
                                     new_message_index,
                                     usage_limits,
                                     result_stream,
-                                    self._result_schema,
+                                    result_schema,
                                     run_context,
                                     self._result_validators,
                                     result_tool_name,
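`run_stream` gains the same per-run `result_type` override; the overloads return an async context manager, so the call itself is not awaited. A sketch using the built-in `test` model, assuming the `stream_text()` helper on `StreamedRunResult`:

```python
import asyncio

from pydantic_ai import Agent

agent = Agent('test')


async def main() -> None:
    # Pass `result_type=...` here to stream towards a custom structured result instead.
    async with agent.run_stream('Tell me a short joke.') as response:
        async for chunk in response.stream_text():
            print(chunk)


asyncio.run(main())
```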
@@ -535,17 +646,37 @@ class Agent(Generic[AgentDeps, ResultData]):
     @overload
     def system_prompt(self, func: Callable[[], Awaitable[str]], /) -> Callable[[], Awaitable[str]]: ...

+    @overload
+    def system_prompt(
+        self, /, *, dynamic: bool = False
+    ) -> Callable[[_system_prompt.SystemPromptFunc[AgentDeps]], _system_prompt.SystemPromptFunc[AgentDeps]]: ...
+
     def system_prompt(
-        self, func: _system_prompt.SystemPromptFunc[AgentDeps], /
-    ) -> _system_prompt.SystemPromptFunc[AgentDeps]:
+        self,
+        func: _system_prompt.SystemPromptFunc[AgentDeps] | None = None,
+        /,
+        *,
+        dynamic: bool = False,
+    ) -> (
+        Callable[[_system_prompt.SystemPromptFunc[AgentDeps]], _system_prompt.SystemPromptFunc[AgentDeps]]
+        | _system_prompt.SystemPromptFunc[AgentDeps]
+    ):
         """Decorator to register a system prompt function.

         Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its only argument.
         Can decorate a sync or async functions.

+        The decorator can be used either bare (`agent.system_prompt`) or as a function call
+        (`agent.system_prompt(...)`), see the examples below.
+
         Overloads for every possible signature of `system_prompt` are included so the decorator doesn't obscure
         the type of the function, see `tests/typed_agent.py` for tests.

+        Args:
+            func: The function to decorate
+            dynamic: If True, the system prompt will be reevaluated even when `messages_history` is provided,
+                see [`SystemPromptPart.dynamic_ref`][pydantic_ai.messages.SystemPromptPart.dynamic_ref]
+
         Example:
         ```python
         from pydantic_ai import Agent, RunContext
@@ -556,17 +687,27 @@ class Agent(Generic[AgentDeps, ResultData]):
         def simple_system_prompt() -> str:
             return 'foobar'

-        @agent.system_prompt
+        @agent.system_prompt(dynamic=True)
         async def async_system_prompt(ctx: RunContext[str]) -> str:
             return f'{ctx.deps} is the best'
-
-        result = agent.run_sync('foobar', deps='spam')
-        print(result.data)
-        #> success (no tool calls)
         ```
         """
-        self._system_prompt_functions.append(_system_prompt.SystemPromptRunner(func))
-        return func
+        if func is None:
+
+            def decorator(
+                func_: _system_prompt.SystemPromptFunc[AgentDeps],
+            ) -> _system_prompt.SystemPromptFunc[AgentDeps]:
+                runner = _system_prompt.SystemPromptRunner(func_, dynamic=dynamic)
+                self._system_prompt_functions.append(runner)
+                if dynamic:
+                    self._system_prompt_dynamic_functions[func_.__qualname__] = runner
+                return func_
+
+            return decorator
+        else:
+            assert not dynamic, "dynamic can't be True in this case"
+            self._system_prompt_functions.append(_system_prompt.SystemPromptRunner(func, dynamic=dynamic))
+            return func

     @overload
     def result_validator(
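A slightly fuller sketch of the dynamic system prompt behaviour documented above (deps values are illustrative): a prompt registered with `dynamic=True` is re-evaluated when earlier messages are replayed via `message_history`, while static prompts are not.

```python
from pydantic_ai import Agent, RunContext

agent = Agent('test', deps_type=str)


@agent.system_prompt  # static: evaluated only when a new conversation starts
def simple_system_prompt() -> str:
    return 'foobar'


@agent.system_prompt(dynamic=True)  # dynamic: re-evaluated on history replay
async def async_system_prompt(ctx: RunContext[str]) -> str:
    return f'{ctx.deps} is the best'


result = agent.run_sync('Hello', deps='spam')
# Passing the history back with different deps re-runs only the dynamic prompt,
# so the system prompt part it produced should now mention 'eggs'.
result2 = agent.run_sync('Hello again', deps='eggs', message_history=result.all_messages())
```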
@@ -818,7 +959,9 @@ class Agent(Generic[AgentDeps, ResultData]):

         return model_

-    async def _prepare_model(self, run_context: RunContext[AgentDeps]) -> models.AgentModel:
+    async def _prepare_model(
+        self, run_context: RunContext[AgentDeps], result_schema: _result.ResultSchema[RunResultData] | None
+    ) -> models.AgentModel:
         """Build tools and create an agent model."""
         function_tools: list[ToolDefinition] = []

@@ -831,10 +974,39 @@ class Agent(Generic[AgentDeps, ResultData]):

         return await run_context.model.agent_model(
             function_tools=function_tools,
-            allow_text_result=self._allow_text_result,
-            result_tools=self._result_schema.tool_defs() if self._result_schema is not None else [],
+            allow_text_result=self._allow_text_result(result_schema),
+            result_tools=result_schema.tool_defs() if result_schema is not None else [],
         )

+    async def _reevaluate_dynamic_prompts(
+        self, messages: list[_messages.ModelMessage], run_context: RunContext[AgentDeps]
+    ) -> None:
+        """Reevaluate any `SystemPromptPart` with dynamic_ref in the provided messages by running the associated runner function."""
+        # Only proceed if there's at least one dynamic runner.
+        if self._system_prompt_dynamic_functions:
+            for msg in messages:
+                if isinstance(msg, _messages.ModelRequest):
+                    for i, part in enumerate(msg.parts):
+                        if isinstance(part, _messages.SystemPromptPart) and part.dynamic_ref:
+                            # Look up the runner by its ref
+                            if runner := self._system_prompt_dynamic_functions.get(part.dynamic_ref):
+                                updated_part_content = await runner.run(run_context)
+                                msg.parts[i] = _messages.SystemPromptPart(
+                                    updated_part_content, dynamic_ref=part.dynamic_ref
+                                )
+
+    def _prepare_result_schema(
+        self, result_type: type[RunResultData] | None
+    ) -> _result.ResultSchema[RunResultData] | None:
+        if result_type is not None:
+            if self._result_validators:
+                raise exceptions.UserError('Cannot set a custom run `result_type` when the agent has result validators')
+            return _result.ResultSchema[result_type].build(
+                result_type, self._result_tool_name, self._result_tool_description
+            )
+        else:
+            return self._result_schema  # pyright: ignore[reportReturnType]
+
     async def _prepare_messages(
         self, user_prompt: str, message_history: list[_messages.ModelMessage] | None, run_context: RunContext[AgentDeps]
     ) -> list[_messages.ModelMessage]:
@@ -850,8 +1022,10 @@ class Agent(Generic[AgentDeps, ResultData]):
             ctx_messages.used = True

         if message_history:
-            #
+            # Shallow copy messages
             messages.extend(message_history)
+            # Reevaluate any dynamic system prompt parts
+            await self._reevaluate_dynamic_prompts(messages, run_context)
             messages.append(_messages.ModelRequest([_messages.UserPromptPart(user_prompt)]))
         else:
             parts = await self._sys_parts(run_context)
@@ -861,8 +1035,11 @@ class Agent(Generic[AgentDeps, ResultData]):
         return messages

     async def _handle_model_response(
-        self, model_response: _messages.ModelResponse, run_context: RunContext[AgentDeps]
-    ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]:
+        self,
+        model_response: _messages.ModelResponse,
+        run_context: RunContext[AgentDeps],
+        result_schema: _result.ResultSchema[RunResultData] | None,
+    ) -> tuple[_MarkFinalResult[RunResultData] | None, list[_messages.ModelRequestPart]]:
         """Process a non-streamed response from the model.

         Returns:
@@ -883,19 +1060,19 @@ class Agent(Generic[AgentDeps, ResultData]):
         # This accounts for cases like anthropic returns that might contain a text response
         # and a tool call response, where the text response just indicates the tool call will happen.
         if tool_calls:
-            return await self._handle_structured_response(tool_calls, run_context)
+            return await self._handle_structured_response(tool_calls, run_context, result_schema)
         elif texts:
             text = '\n\n'.join(texts)
-            return await self._handle_text_response(text, run_context)
+            return await self._handle_text_response(text, run_context, result_schema)
         else:
             raise exceptions.UnexpectedModelBehavior('Received empty model response')

     async def _handle_text_response(
-        self, text: str, run_context: RunContext[AgentDeps]
-    ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]:
+        self, text: str, run_context: RunContext[AgentDeps], result_schema: _result.ResultSchema[RunResultData] | None
+    ) -> tuple[_MarkFinalResult[RunResultData] | None, list[_messages.ModelRequestPart]]:
         """Handle a plain text response from the model for non-streaming responses."""
-        if self._allow_text_result:
-            result_data_input = cast(ResultData, text)
+        if self._allow_text_result(result_schema):
+            result_data_input = cast(RunResultData, text)
             try:
                 result_data = await self._validate_result(result_data_input, run_context, None)
             except _result.ToolRetryError as e:
@@ -911,16 +1088,19 @@ class Agent(Generic[AgentDeps, ResultData]):
             return None, [response]

     async def _handle_structured_response(
-        self, tool_calls: list[_messages.ToolCallPart], run_context: RunContext[AgentDeps]
-    ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]:
+        self,
+        tool_calls: list[_messages.ToolCallPart],
+        run_context: RunContext[AgentDeps],
+        result_schema: _result.ResultSchema[RunResultData] | None,
+    ) -> tuple[_MarkFinalResult[RunResultData] | None, list[_messages.ModelRequestPart]]:
         """Handle a structured response containing tool calls from the model for non-streaming responses."""
         assert tool_calls, 'Expected at least one tool call'

         # first look for the result tool call
-        final_result: _MarkFinalResult[ResultData] | None = None
+        final_result: _MarkFinalResult[RunResultData] | None = None

         parts: list[_messages.ModelRequestPart] = []
-        if result_schema := self._result_schema:
+        if result_schema := result_schema:
             if match := result_schema.find_tool(tool_calls):
                 call, result_tool = match
                 try:
@@ -933,7 +1113,9 @@ class Agent(Generic[AgentDeps, ResultData]):
                 final_result = _MarkFinalResult(result_data, call.tool_name)

         # Then build the other request parts based on end strategy
-        parts += await self._process_function_tools(tool_calls, final_result and final_result.tool_name, run_context)
+        parts += await self._process_function_tools(
+            tool_calls, final_result and final_result.tool_name, run_context, result_schema
+        )

         return final_result, parts

@@ -942,6 +1124,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         tool_calls: list[_messages.ToolCallPart],
         result_tool_name: str | None,
         run_context: RunContext[AgentDeps],
+        result_schema: _result.ResultSchema[RunResultData] | None,
     ) -> list[_messages.ModelRequestPart]:
         """Process function (non-result) tool calls in parallel.

@@ -975,7 +1158,7 @@ class Agent(Generic[AgentDeps, ResultData]):
                     )
                 else:
                     tasks.append(asyncio.create_task(tool.run(call, run_context), name=call.tool_name))
-            elif self._result_schema is not None and call.tool_name in self._result_schema.tools:
+            elif result_schema is not None and call.tool_name in result_schema.tools:
                 # if tool_name is in _result_schema, it means we found a result tool but an error occurred in
                 # validation, we don't add another part here
                 if result_tool_name is not None:
@@ -987,7 +1170,7 @@ class Agent(Generic[AgentDeps, ResultData]):
                         )
                     )
             else:
-                parts.append(self._unknown_tool(call.tool_name, run_context))
+                parts.append(self._unknown_tool(call.tool_name, run_context, result_schema))

         # Run all tool tasks in parallel
         if tasks:
@@ -1000,6 +1183,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         self,
         model_response: models.EitherStreamedResponse,
         run_context: RunContext[AgentDeps],
+        result_schema: _result.ResultSchema[RunResultData] | None,
     ) -> (
         _MarkFinalResult[models.EitherStreamedResponse]
         | tuple[_messages.ModelResponse, list[_messages.ModelRequestPart]]
@@ -1012,7 +1196,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         """
         if isinstance(model_response, models.StreamTextResponse):
             # plain string response
-            if self._allow_text_result:
+            if self._allow_text_result(result_schema):
                 return _MarkFinalResult(model_response, None)
             else:
                 self._incr_result_retry(run_context)
@@ -1026,7 +1210,7 @@ class Agent(Generic[AgentDeps, ResultData]):
             text = ''.join(model_response.get(final=True))
             return _messages.ModelResponse([_messages.TextPart(text)]), [response]
         elif isinstance(model_response, models.StreamStructuredResponse):
-            if self._result_schema is not None:
+            if result_schema is not None:
                 # if there's a result schema, iterate over the stream until we find at least one tool
                 # NOTE: this means we ignore any other tools called here
                 structured_msg = model_response.get()
@@ -1037,7 +1221,7 @@ class Agent(Generic[AgentDeps, ResultData]):
                         break
                     structured_msg = model_response.get()

-                if match := self._result_schema.find_tool(structured_msg.parts):
+                if match := result_schema.find_tool(structured_msg.parts):
                     call, _ = match
                     return _MarkFinalResult(model_response, call.tool_name)

@@ -1057,7 +1241,7 @@ class Agent(Generic[AgentDeps, ResultData]):
             if tool := self._function_tools.get(call.tool_name):
                 tasks.append(asyncio.create_task(tool.run(call, run_context), name=call.tool_name))
             else:
-                parts.append(self._unknown_tool(call.tool_name, run_context))
+                parts.append(self._unknown_tool(call.tool_name, run_context, result_schema))

         with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
             task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
@@ -1068,13 +1252,17 @@ class Agent(Generic[AgentDeps, ResultData]):

     async def _validate_result(
         self,
-        result_data: ResultData,
+        result_data: RunResultData,
         run_context: RunContext[AgentDeps],
         tool_call: _messages.ToolCallPart | None,
-    ) -> ResultData:
-        for validator in self._result_validators:
-            result_data = await validator.validate(result_data, tool_call, run_context)
-        return result_data
+    ) -> RunResultData:
+        if self._result_validators:
+            agent_result_data = cast(ResultData, result_data)
+            for validator in self._result_validators:
+                agent_result_data = await validator.validate(agent_result_data, tool_call, run_context)
+            return cast(RunResultData, agent_result_data)
+        else:
+            return result_data

     def _incr_result_retry(self, run_context: RunContext[AgentDeps]) -> None:
         run_context.retry += 1
@@ -1088,14 +1276,22 @@ class Agent(Generic[AgentDeps, ResultData]):
         messages: list[_messages.ModelRequestPart] = [_messages.SystemPromptPart(p) for p in self._system_prompts]
         for sys_prompt_runner in self._system_prompt_functions:
             prompt = await sys_prompt_runner.run(run_context)
-            messages.append(_messages.SystemPromptPart(prompt))
+            if sys_prompt_runner.dynamic:
+                messages.append(_messages.SystemPromptPart(prompt, dynamic_ref=sys_prompt_runner.function.__qualname__))
+            else:
+                messages.append(_messages.SystemPromptPart(prompt))
         return messages

-    def _unknown_tool(self, tool_name: str, run_context: RunContext[AgentDeps]) -> _messages.RetryPromptPart:
+    def _unknown_tool(
+        self,
+        tool_name: str,
+        run_context: RunContext[AgentDeps],
+        result_schema: _result.ResultSchema[RunResultData] | None,
+    ) -> _messages.RetryPromptPart:
         self._incr_result_retry(run_context)
         names = list(self._function_tools.keys())
-        if self._result_schema:
-            names.extend(self._result_schema.tool_names())
+        if result_schema:
+            names.extend(result_schema.tool_names())
         if names:
             msg = f'Available tools: {", ".join(names)}'
         else:
@@ -1133,6 +1329,10 @@ class Agent(Generic[AgentDeps, ResultData]):
             self.name = name
             return

+    @staticmethod
+    def _allow_text_result(result_schema: _result.ResultSchema[RunResultData] | None) -> bool:
+        return result_schema is None or result_schema.allow_text_result
+
     @property
     @deprecated(
         'The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.', category=None
{pydantic_ai_slim-0.0.17 → pydantic_ai_slim-0.0.18}/pydantic_ai/messages.py

@@ -21,6 +21,12 @@ class SystemPromptPart:
     content: str
     """The content of the prompt."""

+    dynamic_ref: str | None = None
+    """The ref of the dynamic system prompt function that generated this part.
+
+    Only set if system prompt is dynamic, see [`system_prompt`][pydantic_ai.Agent.system_prompt] for more information.
+    """
+
     part_kind: Literal['system-prompt'] = 'system-prompt'
     """Part type identifier, this is available on all parts as a discriminator."""

{pydantic_ai_slim-0.0.17 → pydantic_ai_slim-0.0.18}/pydantic_ai/models/__init__.py

@@ -48,13 +48,12 @@ KnownModelName = Literal[
     'groq:mixtral-8x7b-32768',
     'groq:gemma2-9b-it',
     'groq:gemma-7b-it',
-    'gemini-1.5-flash',
-    'gemini-1.5-pro',
-    'gemini-2.0-flash-exp',
-    'vertexai:gemini-1.5-flash',
-    'vertexai:gemini-1.5-pro',
-
-    # don't start with "mistral", we add the "mistral:" prefix to all to be explicit
+    'google-gla:gemini-1.5-flash',
+    'google-gla:gemini-1.5-pro',
+    'google-gla:gemini-2.0-flash-exp',
+    'google-vertex:gemini-1.5-flash',
+    'google-vertex:gemini-1.5-pro',
+    'google-vertex:gemini-2.0-flash-exp',
     'mistral:mistral-small-latest',
     'mistral:mistral-large-latest',
     'mistral:codestral-latest',
@@ -76,9 +75,9 @@ KnownModelName = Literal[
     'ollama:qwen2',
     'ollama:qwen2.5',
     'ollama:starcoder2',
-    'claude-3-5-haiku-latest',
-    'claude-3-5-sonnet-latest',
-    'claude-3-opus-latest',
+    'anthropic:claude-3-5-haiku-latest',
+    'anthropic:claude-3-5-sonnet-latest',
+    'anthropic:claude-3-opus-latest',
     'test',
 ]
 """Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent].
@@ -274,6 +273,15 @@ def infer_model(model: Model | KnownModelName) -> Model:
         from .openai import OpenAIModel

         return OpenAIModel(model[7:])
+    elif model.startswith(('gpt', 'o1')):
+        from .openai import OpenAIModel
+
+        return OpenAIModel(model)
+    elif model.startswith('google-gla'):
+        from .gemini import GeminiModel
+
+        return GeminiModel(model[11:])  # pyright: ignore[reportArgumentType]
+    # backwards compatibility with old model names (ex, gemini-1.5-flash -> google-gla:gemini-1.5-flash)
     elif model.startswith('gemini'):
         from .gemini import GeminiModel

@@ -283,6 +291,11 @@ def infer_model(model: Model | KnownModelName) -> Model:
         from .groq import GroqModel

         return GroqModel(model[5:])  # pyright: ignore[reportArgumentType]
+    elif model.startswith('google-vertex'):
+        from .vertexai import VertexAIModel
+
+        return VertexAIModel(model[14:])  # pyright: ignore[reportArgumentType]
+    # backwards compatibility with old model names (ex, vertexai:gemini-1.5-flash -> google-vertex:gemini-1.5-flash)
     elif model.startswith('vertexai:'):
         from .vertexai import VertexAIModel

@@ -295,6 +308,11 @@ def infer_model(model: Model | KnownModelName) -> Model:
         from .ollama import OllamaModel

         return OllamaModel(model[7:])
+    elif model.startswith('anthropic'):
+        from .anthropic import AnthropicModel
+
+        return AnthropicModel(model[10:])
+    # backwards compatibility with old model names (ex, claude-3-5-sonnet-latest -> anthropic:claude-3-5-sonnet-latest)
     elif model.startswith('claude'):
         from .anthropic import AnthropicModel

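Net effect of the `KnownModelName` and `infer_model` changes: Gemini and Anthropic models now use explicit provider prefixes, with the old bare spellings kept for backwards compatibility. A sketch (assumes the relevant provider API keys are set in the environment):

```python
from pydantic_ai import Agent

# New provider-prefixed names:
gemini_agent = Agent('google-gla:gemini-1.5-flash')      # Generative Language API
vertex_agent = Agent('google-vertex:gemini-1.5-flash')   # Vertex AI
claude_agent = Agent('anthropic:claude-3-5-haiku-latest')

# Old spellings such as 'gemini-1.5-flash', 'vertexai:gemini-1.5-flash' and
# 'claude-3-5-haiku-latest' are still accepted and resolve to the same model classes.
```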
{pydantic_ai_slim-0.0.17 → pydantic_ai_slim-0.0.18}/pydantic_ai/models/gemini.py

@@ -111,7 +111,7 @@ class GeminiModel(Model):
         )

     def name(self) -> str:
-        return self.model_name
+        return f'google-gla:{self.model_name}'


 class AuthProtocol(Protocol):
@@ -693,7 +693,7 @@

     def _simplify(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
         schema.pop('title', None)
-
+        schema.pop('default', None)
         if ref := schema.pop('$ref', None):
             # noinspection PyTypeChecker
             key = re.sub(r'^#/\$defs/', '', ref)
@@ -708,11 +708,12 @@
         if any_of := schema.get('anyOf'):
             for item_schema in any_of:
                 self._simplify(item_schema, refs_stack)
-            if len(any_of) == 2 and {'type': 'null'} in any_of
+            if len(any_of) == 2 and {'type': 'null'} in any_of:
                 for item_schema in any_of:
                     if item_schema != {'type': 'null'}:
                         schema.clear()
                         schema.update(item_schema)
+                        schema['nullable'] = True
                         return

         type_ = schema.get('type')
@@ -721,6 +722,12 @@
             self._object(schema, refs_stack)
         elif type_ == 'array':
             return self._array(schema, refs_stack)
+        elif type_ == 'string' and (fmt := schema.pop('format', None)):
+            description = schema.get('description')
+            if description:
+                schema['description'] = f'{description} (format: {fmt})'
+            else:
+                schema['description'] = f'Format: {fmt}'

     def _object(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
         ad_props = schema.pop('additionalProperties', None)
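Illustration of what the `_GeminiJsonSchema` changes do to a schema (the input dicts are made-up examples, not taken from the package):

```python
# An optional field: a two-member anyOf containing {'type': 'null'} is collapsed
# into the non-null schema, now marked with Gemini's `nullable` flag; `default`
# is also stripped by the new pop.
before = {'anyOf': [{'type': 'string'}, {'type': 'null'}], 'default': None}
after = {'type': 'string', 'nullable': True}

# String formats are no longer sent as `format`; they are folded into the
# description instead.
before = {'type': 'string', 'format': 'date', 'description': 'Date of birth'}
after = {'type': 'string', 'description': 'Date of birth (format: date)'}
```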
{pydantic_ai_slim-0.0.17 → pydantic_ai_slim-0.0.18}/pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "hatchling.build"

 [project]
 name = "pydantic-ai-slim"
-version = "0.0.17"
+version = "0.0.18"
 description = "Agent Framework / shim to use Pydantic with LLMs, slim package"
 authors = [
     { name = "Samuel Colvin", email = "samuel@pydantic.dev" },
@@ -60,6 +60,7 @@ dev = [
     "pytest-examples>=0.0.14",
     "pytest-mock>=3.14.0",
     "pytest-pretty>=1.2.0",
+    "diff-cover>=9.2.0",
 ]

 [tool.hatch.build.targets.wheel]
|