pydantic-ai-slim 0.0.14__py3-none-any.whl → 0.0.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/__init__.py +2 -1
- pydantic_ai/_griffe.py +1 -2
- pydantic_ai/_result.py +2 -2
- pydantic_ai/agent.py +130 -65
- pydantic_ai/models/gemini.py +11 -4
- pydantic_ai/models/mistral.py +10 -15
- pydantic_ai/models/ollama.py +4 -1
- pydantic_ai/models/test.py +18 -7
- pydantic_ai/result.py +43 -19
- pydantic_ai/settings.py +5 -1
- pydantic_ai/tools.py +16 -23
- {pydantic_ai_slim-0.0.14.dist-info → pydantic_ai_slim-0.0.16.dist-info}/METADATA +1 -2
- pydantic_ai_slim-0.0.16.dist-info/RECORD +26 -0
- pydantic_ai_slim-0.0.14.dist-info/RECORD +0 -26
- {pydantic_ai_slim-0.0.14.dist-info → pydantic_ai_slim-0.0.16.dist-info}/WHEEL +0 -0
pydantic_ai/__init__.py
CHANGED

```diff
@@ -1,11 +1,12 @@
 from importlib.metadata import version

-from .agent import Agent
+from .agent import Agent, capture_run_messages
 from .exceptions import AgentRunError, ModelRetry, UnexpectedModelBehavior, UsageLimitExceeded, UserError
 from .tools import RunContext, Tool

 __all__ = (
     'Agent',
+    'capture_run_messages',
     'RunContext',
     'Tool',
     'AgentRunError',
```
pydantic_ai/_griffe.py
CHANGED

```diff
@@ -4,8 +4,7 @@ import re
 from inspect import Signature
 from typing import Any, Callable, Literal, cast

-from _griffe.enumerations import DocstringSectionKind
-from _griffe.models import Docstring, Object as GriffeObject
+from griffe import Docstring, DocstringSectionKind, Object as GriffeObject

 DocstringStyle = Literal['google', 'numpy', 'sphinx']

```
pydantic_ai/_result.py
CHANGED

```diff
@@ -12,8 +12,8 @@ from typing_extensions import Self, TypeAliasType, TypedDict

 from . import _utils, messages as _messages
 from .exceptions import ModelRetry
-from .result import ResultData
-from .tools import AgentDeps, ResultValidatorFunc, RunContext, ToolDefinition
+from .result import ResultData, ResultValidatorFunc
+from .tools import AgentDeps, RunContext, ToolDefinition


 @dataclass
```
pydantic_ai/agent.py
CHANGED

````diff
@@ -5,12 +5,12 @@ import dataclasses
 import inspect
 from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
 from contextlib import asynccontextmanager, contextmanager
-from dataclasses import dataclass, field
+from contextvars import ContextVar
 from types import FrameType
 from typing import Any, Callable, Generic, Literal, cast, final, overload

 import logfire_api
-from typing_extensions import assert_never
+from typing_extensions import assert_never, deprecated

 from . import (
     _result,
@@ -35,10 +35,20 @@ from .tools import (
     ToolPrepareFunc,
 )

-__all__ = 'Agent', 'EndStrategy'
+__all__ = 'Agent', 'capture_run_messages', 'EndStrategy'

 _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')

+# while waiting for https://github.com/pydantic/logfire/issues/745
+try:
+    import logfire._internal.stack_info
+except ImportError:
+    pass
+else:
+    from pathlib import Path
+
+    logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)
+
 NoneType = type(None)
 EndStrategy = Literal['early', 'exhaustive']
 """The strategy for handling multiple tool calls when a final result is found.
@@ -49,7 +59,7 @@ EndStrategy = Literal['early', 'exhaustive']


 @final
-@dataclass(init=False)
+@dataclasses.dataclass(init=False)
 class Agent(Generic[AgentDeps, ResultData]):
     """Class for defining "agents" - a way to have a specific type of "conversation" with an LLM.

@@ -89,23 +99,17 @@ class Agent(Generic[AgentDeps, ResultData]):
         be merged with this value, with the runtime argument taking priority.
     """

-    _result_schema: _result.ResultSchema[ResultData] | None = field(repr=False)
-    _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = field(repr=False)
-    _allow_text_result: bool = field(repr=False)
-    _system_prompts: tuple[str, ...] = field(repr=False)
-    _function_tools: dict[str, Tool[AgentDeps]] = field(repr=False)
-    _default_retries: int = field(repr=False)
-    _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = field(repr=False)
-    _deps_type: type[AgentDeps] = field(repr=False)
-    _max_result_retries: int = field(repr=False)
-    _override_deps: _utils.Option[AgentDeps] = field(default=None, repr=False)
-    _override_model: _utils.Option[models.Model] = field(default=None, repr=False)
+    _result_schema: _result.ResultSchema[ResultData] | None = dataclasses.field(repr=False)
+    _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = dataclasses.field(repr=False)
+    _allow_text_result: bool = dataclasses.field(repr=False)
+    _system_prompts: tuple[str, ...] = dataclasses.field(repr=False)
+    _function_tools: dict[str, Tool[AgentDeps]] = dataclasses.field(repr=False)
+    _default_retries: int = dataclasses.field(repr=False)
+    _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = dataclasses.field(repr=False)
+    _deps_type: type[AgentDeps] = dataclasses.field(repr=False)
+    _max_result_retries: int = dataclasses.field(repr=False)
+    _override_deps: _utils.Option[AgentDeps] = dataclasses.field(default=None, repr=False)
+    _override_model: _utils.Option[models.Model] = dataclasses.field(default=None, repr=False)

     def __init__(
         self,
@@ -161,7 +165,6 @@ class Agent(Generic[AgentDeps, ResultData]):
         self.end_strategy = end_strategy
         self.name = name
         self.model_settings = model_settings
-        self.last_run_messages = None
         self._result_schema = _result.ResultSchema[result_type].build(
             result_type, result_tool_name, result_tool_description
         )
@@ -190,6 +193,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         deps: AgentDeps = None,
         model_settings: ModelSettings | None = None,
         usage_limits: UsageLimits | None = None,
+        usage: result.Usage | None = None,
         infer_name: bool = True,
     ) -> result.RunResult[ResultData]:
         """Run the agent with a user prompt in async mode.
@@ -212,6 +216,7 @@ class Agent(Generic[AgentDeps, ResultData]):
             deps: Optional dependencies to use for this run.
             model_settings: Optional settings to use for this model's request.
             usage_limits: Optional limits on model request count or token usage.
+            usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
            infer_name: Whether to try to infer the agent name from the call frame if it's not set.

         Returns:
@@ -219,7 +224,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         """
         if infer_name and self.name is None:
             self._infer_name(inspect.currentframe())
-        model_used, mode_selection = await self._get_model(model)
+        model_used = await self._get_model(model)

         deps = self._get_deps(deps)
         new_message_index = len(message_history) if message_history else 0
@@ -228,40 +233,36 @@ class Agent(Generic[AgentDeps, ResultData]):
             '{agent_name} run {prompt=}',
             prompt=user_prompt,
             agent=self,
-            mode_selection=mode_selection,
             model_name=model_used.name(),
             agent_name=self.name or 'agent',
         ) as run_span:
-            run_context = RunContext(deps, 0, [], None, model_used)
+            run_context = RunContext(deps, model_used, usage or result.Usage(), user_prompt)
             messages = await self._prepare_messages(user_prompt, message_history, run_context)
-            self.last_run_messages = messages
+            run_context.messages = messages

             for tool in self._function_tools.values():
                 tool.current_retry = 0

-            usage = result.Usage(requests=0)
             model_settings = merge_model_settings(self.model_settings, model_settings)
             usage_limits = usage_limits or UsageLimits()

-            run_step = 0
             while True:
-                usage_limits.check_before_request(usage)
+                usage_limits.check_before_request(run_context.usage)

-                run_step += 1
-                with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
+                run_context.run_step += 1
+                with _logfire.span('preparing model and tools {run_step=}', run_step=run_context.run_step):
                     agent_model = await self._prepare_model(run_context)

-                with _logfire.span('model request', run_step=run_step) as model_req_span:
+                with _logfire.span('model request', run_step=run_context.run_step) as model_req_span:
                     model_response, request_usage = await agent_model.request(messages, model_settings)
                     model_req_span.set_attribute('response', model_response)
                     model_req_span.set_attribute('usage', request_usage)

                 messages.append(model_response)
-                usage += request_usage
-                usage.requests += 1
-                usage_limits.check_tokens(request_usage)
+                run_context.usage.incr(request_usage, requests=1)
+                usage_limits.check_tokens(run_context.usage)

-                with _logfire.span('handle model response', run_step=run_step) as handle_span:
+                with _logfire.span('handle model response', run_step=run_context.run_step) as handle_span:
                     final_result, tool_responses = await self._handle_model_response(model_response, run_context)

                     if tool_responses:
@@ -272,10 +273,10 @@ class Agent(Generic[AgentDeps, ResultData]):
                     if final_result is not None:
                         result_data = final_result.data
                         run_span.set_attribute('all_messages', messages)
-                        run_span.set_attribute('usage', usage)
+                        run_span.set_attribute('usage', run_context.usage)
                         handle_span.set_attribute('result', result_data)
                         handle_span.message = 'handle model response -> final result'
-                        return result.RunResult(messages, new_message_index, result_data, usage)
+                        return result.RunResult(messages, new_message_index, result_data, run_context.usage)
                     else:
                         # continue the conversation
                         handle_span.set_attribute('tool_responses', tool_responses)
@@ -291,6 +292,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         deps: AgentDeps = None,
         model_settings: ModelSettings | None = None,
         usage_limits: UsageLimits | None = None,
+        usage: result.Usage | None = None,
         infer_name: bool = True,
     ) -> result.RunResult[ResultData]:
         """Run the agent with a user prompt synchronously.
@@ -317,6 +319,7 @@ class Agent(Generic[AgentDeps, ResultData]):
             deps: Optional dependencies to use for this run.
             model_settings: Optional settings to use for this model's request.
             usage_limits: Optional limits on model request count or token usage.
+            usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
             infer_name: Whether to try to infer the agent name from the call frame if it's not set.

         Returns:
@@ -332,6 +335,7 @@ class Agent(Generic[AgentDeps, ResultData]):
                 deps=deps,
                 model_settings=model_settings,
                 usage_limits=usage_limits,
+                usage=usage,
                 infer_name=False,
             )
         )
@@ -346,6 +350,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         deps: AgentDeps = None,
         model_settings: ModelSettings | None = None,
         usage_limits: UsageLimits | None = None,
+        usage: result.Usage | None = None,
         infer_name: bool = True,
     ) -> AsyncIterator[result.StreamedRunResult[AgentDeps, ResultData]]:
         """Run the agent with a user prompt in async mode, returning a streamed response.
@@ -369,6 +374,7 @@ class Agent(Generic[AgentDeps, ResultData]):
             deps: Optional dependencies to use for this run.
             model_settings: Optional settings to use for this model's request.
             usage_limits: Optional limits on model request count or token usage.
+            usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
             infer_name: Whether to try to infer the agent name from the call frame if it's not set.

         Returns:
@@ -378,7 +384,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         # f_back because `asynccontextmanager` adds one frame
         if frame := inspect.currentframe():  # pragma: no branch
             self._infer_name(frame.f_back)
-        model_used, mode_selection = await self._get_model(model)
+        model_used = await self._get_model(model)

         deps = self._get_deps(deps)
         new_message_index = len(message_history) if message_history else 0
@@ -387,32 +393,29 @@ class Agent(Generic[AgentDeps, ResultData]):
             '{agent_name} run stream {prompt=}',
             prompt=user_prompt,
             agent=self,
-            mode_selection=mode_selection,
             model_name=model_used.name(),
             agent_name=self.name or 'agent',
         ) as run_span:
-            run_context = RunContext(deps, 0, [], None, model_used)
+            run_context = RunContext(deps, model_used, usage or result.Usage(), user_prompt)
             messages = await self._prepare_messages(user_prompt, message_history, run_context)
-            self.last_run_messages = messages
+            run_context.messages = messages

             for tool in self._function_tools.values():
                 tool.current_retry = 0

-            usage = result.Usage()
             model_settings = merge_model_settings(self.model_settings, model_settings)
             usage_limits = usage_limits or UsageLimits()

-            run_step = 0
             while True:
-                run_step += 1
-                usage_limits.check_before_request(usage)
+                run_context.run_step += 1
+                usage_limits.check_before_request(run_context.usage)

-                with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
+                with _logfire.span('preparing model and tools {run_step=}', run_step=run_context.run_step):
                     agent_model = await self._prepare_model(run_context)

-                with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span:
+                with _logfire.span('model request {run_step=}', run_step=run_context.run_step) as model_req_span:
                     async with agent_model.request_stream(messages, model_settings) as model_response:
-                        usage.requests += 1
+                        run_context.usage.requests += 1
                         model_req_span.set_attribute('response_type', model_response.__class__.__name__)
                         # We want to end the "model request" span here, but we can't exit the context manager
                         # in the traditional way
@@ -448,7 +451,6 @@ class Agent(Generic[AgentDeps, ResultData]):
                         yield result.StreamedRunResult(
                             messages,
                             new_message_index,
-                            usage,
                             usage_limits,
                             result_stream,
                             self._result_schema,
@@ -472,8 +474,8 @@ class Agent(Generic[AgentDeps, ResultData]):
                     handle_span.message = f'handle model response -> {tool_responses_str}'
                 # the model_response should have been fully streamed by now, we can add its usage
                 model_response_usage = model_response.usage()
-                usage += model_response_usage
-                usage_limits.check_tokens(usage)
+                run_context.usage.incr(model_response_usage)
+                usage_limits.check_tokens(run_context.usage)

     @contextmanager
     def override(
@@ -614,7 +616,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         #> success (no tool calls)
         ```
         """
-        self._result_validators.append(_result.ResultValidator(func))
+        self._result_validators.append(_result.ResultValidator[AgentDeps, Any](func))
         return func

     @overload
@@ -784,14 +786,14 @@ class Agent(Generic[AgentDeps, ResultData]):

         self._function_tools[tool.name] = tool

-    async def _get_model(self, model: models.Model | models.KnownModelName | None) -> tuple[models.Model, str]:
+    async def _get_model(self, model: models.Model | models.KnownModelName | None) -> models.Model:
         """Create a model configured for this agent.

         Args:
             model: model to use for this run, required if `model` was not set when creating the agent.

         Returns:
+            The model used
         """
         model_: models.Model
         if some_model := self._override_model:
@@ -802,18 +804,15 @@ class Agent(Generic[AgentDeps, ResultData]):
                 '(Even when `override(model=...)` is customizing the model that will actually be called)'
             )
             model_ = some_model.value
-            mode_selection = 'override-model'
         elif model is not None:
             model_ = models.infer_model(model)
-            mode_selection = 'custom'
         elif self.model is not None:
             # noinspection PyTypeChecker
             model_ = self.model = models.infer_model(self.model)
-            mode_selection = 'from-agent'
         else:
             raise exceptions.UserError('`model` must be set either when creating the agent or when calling it.')

-        return model_, mode_selection
+        return model_

     async def _prepare_model(self, run_context: RunContext[AgentDeps]) -> models.AgentModel:
         """Build tools and create an agent model."""
@@ -835,14 +834,25 @@ class Agent(Generic[AgentDeps, ResultData]):
     async def _prepare_messages(
         self, user_prompt: str, message_history: list[_messages.ModelMessage] | None, run_context: RunContext[AgentDeps]
     ) -> list[_messages.ModelMessage]:
+        try:
+            ctx_messages = _messages_ctx_var.get()
+        except LookupError:
+            messages: list[_messages.ModelMessage] = []
+        else:
+            if ctx_messages.used:
+                messages = []
+            else:
+                messages = ctx_messages.messages
+                ctx_messages.used = True
+
         if message_history:
             # shallow copy messages
-            messages
+            messages.extend(message_history)
             messages.append(_messages.ModelRequest([_messages.UserPromptPart(user_prompt)]))
         else:
             parts = await self._sys_parts(run_context)
             parts.append(_messages.UserPromptPart(user_prompt))
-            messages
+            messages.append(_messages.ModelRequest(parts))

         return messages
@@ -864,11 +874,15 @@ class Agent(Generic[AgentDeps, ResultData]):
             else:
                 tool_calls.append(part)

-        if texts:
+        # At the moment, we prioritize at least executing tool calls if they are present.
+        # In the future, we'd consider making this configurable at the agent or run level.
+        # This accounts for cases like anthropic returns that might contain a text response
+        # and a tool call response, where the text response just indicates the tool call will happen.
+        if tool_calls:
+            return await self._handle_structured_response(tool_calls, run_context)
+        elif texts:
             text = '\n\n'.join(texts)
             return await self._handle_text_response(text, run_context)
-        elif tool_calls:
-            return await self._handle_structured_response(tool_calls, run_context)
         else:
             raise exceptions.UnexpectedModelBehavior('Received empty model response')
@@ -1115,8 +1129,59 @@ class Agent(Generic[AgentDeps, ResultData]):
                 self.name = name
                 return

+    @property
+    @deprecated(
+        'The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.', category=None
+    )
+    def last_run_messages(self) -> list[_messages.ModelMessage]:
+        raise AttributeError('The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.')
+
+
+@dataclasses.dataclass
+class _RunMessages:
+    messages: list[_messages.ModelMessage]
+    used: bool = False
+
+
+_messages_ctx_var: ContextVar[_RunMessages] = ContextVar('var')
+
+
+@contextmanager
+def capture_run_messages() -> Iterator[list[_messages.ModelMessage]]:
+    """Context manager to access the messages used in a [`run`][pydantic_ai.Agent.run], [`run_sync`][pydantic_ai.Agent.run_sync], or [`run_stream`][pydantic_ai.Agent.run_stream] call.
+
+    Useful when a run may raise an exception, see [model errors](../agents.md#model-errors) for more information.
+
+    Examples:
+    ```python
+    from pydantic_ai import Agent, capture_run_messages
+
+    agent = Agent('test')
+
+    with capture_run_messages() as messages:
+        try:
+            result = agent.run_sync('foobar')
+        except Exception:
+            print(messages)
+            raise
+    ```
+
+    !!! note
+        If you call `run`, `run_sync`, or `run_stream` more than once within a single `capture_run_messages` context,
+        `messages` will represent the messages exchanged during the first call only.
+    """
+    try:
+        yield _messages_ctx_var.get().messages
+    except LookupError:
+        messages: list[_messages.ModelMessage] = []
+        token = _messages_ctx_var.set(_RunMessages(messages))
+        try:
+            yield messages
+        finally:
+            _messages_ctx_var.reset(token)
+

-@dataclass
+@dataclasses.dataclass
 class _MarkFinalResult(Generic[ResultData]):
     """Marker class to indicate that the result is the final result.
````
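Two user-facing changes stand out here: runs can now seed their token accounting via the new `usage` argument, and `capture_run_messages` replaces the removed `last_run_messages` attribute. A minimal sketch of threading usage through a nested run, using the built-in `'test'` model so it runs without API keys (the `delegate` tool name is illustrative, not from the diff):

```python
from pydantic_ai import Agent, RunContext

parent_agent = Agent('test')
child_agent = Agent('test')


@parent_agent.tool
async def delegate(ctx: RunContext[None], prompt: str) -> str:
    # pass the parent run's Usage in, so the child run's requests and
    # tokens are accumulated onto the same object
    result = await child_agent.run(prompt, usage=ctx.usage)
    return result.data


run_result = parent_agent.run_sync('hello')
print(run_result.usage())  # includes the child agent's requests too
```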
pydantic_ai/models/gemini.py
CHANGED

```diff
@@ -444,7 +444,8 @@ def _content_model_response(m: ModelResponse) -> _GeminiContent:
         if isinstance(item, ToolCallPart):
             parts.append(_function_call_part_from_call(item))
         elif isinstance(item, TextPart):
-            parts.append(_GeminiTextPart(text=item.content))
+            if item.content:
+                parts.append(_GeminiTextPart(text=item.content))
         else:
             assert_never(item)
     return _GeminiContent(role='model', parts=parts)
@@ -701,7 +702,7 @@ class _GeminiJsonSchema:

     def _simplify(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
         schema.pop('title', None)
-        schema.pop('default', None)
+        default = schema.pop('default', _utils.UNSET)
         if ref := schema.pop('$ref', None):
             # noinspection PyTypeChecker
             key = re.sub(r'^#/\$defs/', '', ref)
@@ -714,8 +715,14 @@ class _GeminiJsonSchema:
             return

         if any_of := schema.get('anyOf'):
-            for item_schema in any_of:
-                self._simplify(item_schema, refs_stack)
+            for item_schema in any_of:
+                self._simplify(item_schema, refs_stack)
+            if len(any_of) == 2 and {'type': 'null'} in any_of and default is None:
+                for item_schema in any_of:
+                    if item_schema != {'type': 'null'}:
+                        schema.clear()
+                        schema.update(item_schema)
+                        return

         type_ = schema.get('type')
```
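The new `anyOf` branch collapses the common `field: T | None = None` pattern, which Gemini's schema dialect can't represent, down to the non-null member. A standalone sketch of just that transformation (outside the `_GeminiJsonSchema` class, with `...` standing in for the library's `_utils.UNSET` sentinel):

```python
from typing import Any


def collapse_nullable(schema: dict[str, Any]) -> None:
    """Collapse `anyOf: [X, {'type': 'null'}]` with a None default to just X, in place."""
    default = schema.pop('default', ...)
    if any_of := schema.get('anyOf'):
        if len(any_of) == 2 and {'type': 'null'} in any_of and default is None:
            for item_schema in any_of:
                if item_schema != {'type': 'null'}:
                    schema.clear()
                    schema.update(item_schema)
                    return


s = {'anyOf': [{'type': 'integer'}, {'type': 'null'}], 'default': None}
collapse_nullable(s)
print(s)  # {'type': 'integer'}
```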
pydantic_ai/models/mistral.py
CHANGED

```diff
@@ -8,6 +8,7 @@ from datetime import datetime, timezone
 from itertools import chain
 from typing import Any, Callable, Literal, Union

+import pydantic_core
 from httpx import AsyncClient as AsyncHTTPClient, Timeout
 from typing_extensions import assert_never

@@ -39,7 +40,6 @@ from . import (
 )

 try:
-    from json_repair import repair_json
     from mistralai import (
         UNSET,
         CompletionChunk as MistralCompletionChunk,
@@ -198,11 +198,10 @@ class MistralAgentModel(AgentModel):
         """Create a streaming completion request to the Mistral model."""
         response: MistralEventStreamAsync[MistralCompletionEvent] | None
         mistral_messages = list(chain(*(self._map_message(m) for m in messages)))
-
         model_settings = model_settings or {}

         if self.result_tools and self.function_tools or self.function_tools:
-            # Function Calling
+            # Function Calling
             response = await self.client.chat.stream_async(
                 model=str(self.model_name),
                 messages=mistral_messages,
@@ -218,9 +217,9 @@ class MistralAgentModel(AgentModel):
         elif self.result_tools:
             # Json Mode
             parameters_json_schemas = [tool.parameters_json_schema for tool in self.result_tools]
-
             user_output_format_message = self._generate_user_output_format(parameters_json_schemas)
             mistral_messages.append(user_output_format_message)
+
             response = await self.client.chat.stream_async(
                 model=str(self.model_name),
                 messages=mistral_messages,
@@ -270,12 +269,13 @@ class MistralAgentModel(AgentModel):
     @staticmethod
     def _process_response(response: MistralChatCompletionResponse) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
+        assert response.choices, 'Unexpected empty response choice.'
+
         if response.created:
             timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
         else:
             timestamp = _now_utc()

-        assert response.choices, 'Unexpected empty response choice.'
         choice = response.choices[0]
         content = choice.message.content
         tool_calls = choice.message.tool_calls
@@ -546,20 +546,15 @@ class MistralStreamStructuredResponse(StreamStructuredResponse):
             calls.append(tool)

         elif self._delta_content and self._result_tools:
-            output_json = repair_json(self._delta_content, return_objects=True)
-            assert isinstance(
-                output_json, dict
-            ), f'Expected repair_json as type dict, invalid type: {type(output_json)}'
+            output_json: dict[str, Any] | None = pydantic_core.from_json(
+                self._delta_content, allow_partial='trailing-strings'
+            )

             if output_json:
                 for result_tool in self._result_tools.values():
-                    # NOTE: Additional verification to prevent JSON validation to crash in `_result.py`
+                    # NOTE: Additional verification to prevent JSON validation to crash in `_result.py`
                     # Ensures required parameters in the JSON schema are respected, especially for stream-based return types.
-                    #
-                    # when `{"response":` then `repair_json` sets `{"response": ""}` (type not found default str)
-                    # when `{"response": {` then `repair_json` sets `{"response": {}}` (type found)
-                    # This ensures it's corrected to `{"response": {}}` and other required parameters and type.
+                    # Example with BaseModel and required fields.
                     if not self._validate_required_json_schema(output_json, result_tool.parameters_json_schema):
                         continue
```
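The `json_repair` dependency is dropped in favour of `pydantic_core.from_json`, whose `allow_partial='trailing-strings'` mode parses JSON that has been cut off mid-stream, keeping an incomplete trailing string rather than discarding it:

```python
import pydantic_core

# a structured response interrupted partway through streaming
chunk = '{"city": "Amsterdam", "country": "Nether'

output = pydantic_core.from_json(chunk, allow_partial='trailing-strings')
print(output)  # {'city': 'Amsterdam', 'country': 'Nether'}
```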
pydantic_ai/models/ollama.py
CHANGED

```diff
@@ -71,6 +71,7 @@ class OllamaModel(Model):
         model_name: OllamaModelName,
         *,
         base_url: str | None = 'http://localhost:11434/v1/',
+        api_key: str = 'ollama',
         openai_client: AsyncOpenAI | None = None,
         http_client: AsyncHTTPClient | None = None,
     ):
@@ -83,6 +84,8 @@ class OllamaModel(Model):
             model_name: The name of the Ollama model to use. List of models available [here](https://ollama.com/library)
                 You must first download the model (`ollama pull <MODEL-NAME>`) in order to use the model
             base_url: The base url for the ollama requests. The default value is the ollama default
+            api_key: The API key to use for authentication. Defaults to 'ollama' for local instances,
+                but can be customized for proxy setups that require authentication
             openai_client: An existing
                 [`AsyncOpenAI`](https://github.com/openai/openai-python?tab=readme-ov-file#async-usage)
                 client to use, if provided, `base_url` and `http_client` must be `None`.
@@ -96,7 +99,7 @@ class OllamaModel(Model):
         else:
             # API key is not required for ollama but a value is required to create the client
             http_client_ = http_client or cached_async_http_client()
-            oai_client = AsyncOpenAI(base_url=base_url, api_key='ollama', http_client=http_client_)
+            oai_client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client_)
         self.openai_model = OpenAIModel(model_name=model_name, openai_client=oai_client)

     async def agent_model(
```
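The new `api_key` parameter means an Ollama server behind an authenticating reverse proxy can now be reached without constructing an `AsyncOpenAI` client by hand; the hostname and token below are illustrative:

```python
from pydantic_ai.models.ollama import OllamaModel

# local instance: the placeholder 'ollama' key is used by default
local_model = OllamaModel('llama3.2')

# remote instance behind a proxy that checks the Authorization header
remote_model = OllamaModel(
    'llama3.2',
    base_url='https://ollama.example.com/v1/',
    api_key='my-proxy-token',
)
```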
pydantic_ai/models/test.py
CHANGED

```diff
@@ -16,6 +16,7 @@ from ..messages import (
     ModelMessage,
     ModelRequest,
     ModelResponse,
+    ModelResponsePart,
     RetryPromptPart,
     TextPart,
     ToolCallPart,
@@ -177,13 +178,23 @@ class TestAgentModel(AgentModel):
         # check if there are any retry prompts, if so retry them
         new_retry_names = {p.tool_name for p in last_message.parts if isinstance(p, RetryPromptPart)}
         if new_retry_names:
-            return ModelResponse(
-                parts=[
-                    ToolCallPart.from_raw_args(name, self.gen_tool_args(args))
-                    for name, args in self.tool_calls
-                    if name in new_retry_names
-                ]
-            )
+            # Handle retries for both function tools and result tools
+            # Check function tools first
+            retry_parts: list[ModelResponsePart] = [
+                ToolCallPart.from_raw_args(name, self.gen_tool_args(args))
+                for name, args in self.tool_calls
+                if name in new_retry_names
+            ]
+            # Check result tools
+            if self.result_tools:
+                retry_parts.extend(
+                    [
+                        ToolCallPart.from_raw_args(tool.name, self.gen_tool_args(tool))
+                        for tool in self.result_tools
+                        if tool.name in new_retry_names
+                    ]
+                )
+            return ModelResponse(parts=retry_parts)

         if response_text := self.result.left:
             if response_text.value is None:
```
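The effect of this change: when a retry is requested for the result tool, `TestAgentModel` previously rebuilt tool calls only from its function tools, so a `RetryPromptPart` naming the result tool went unanswered. A sketch exercising the fixed path, assuming `TestModel`'s default behaviour of calling the result tool (the validator here is illustrative):

```python
from pydantic_ai import Agent, ModelRetry, RunContext

agent = Agent('test', result_type=int)
attempts: list[int] = []


@agent.result_validator
def require_second_try(ctx: RunContext[None], value: int) -> int:
    attempts.append(value)
    if len(attempts) == 1:
        # emits a RetryPromptPart for the result tool; TestAgentModel
        # now responds by calling the result tool again
        raise ModelRetry('please try again')
    return value


agent.run_sync('pick a number')
print(len(attempts))  # 2
```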
pydantic_ai/result.py
CHANGED

```diff
@@ -2,11 +2,13 @@ from __future__ import annotations as _annotations

 from abc import ABC, abstractmethod
 from collections.abc import AsyncIterator, Awaitable, Callable
+from copy import copy
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Generic,
+from typing import Generic, Union, cast

 import logfire_api
+from typing_extensions import TypeVar

 from . import _result, _utils, exceptions, messages as _messages, models
 from .settings import UsageLimits
@@ -14,21 +16,37 @@ from .tools import AgentDeps, RunContext

 __all__ = (
     'ResultData',
+    'ResultValidatorFunc',
     'Usage',
     'RunResult',
     'StreamedRunResult',
 )


-ResultData = TypeVar('ResultData')
+ResultData = TypeVar('ResultData', default=str)
 """Type variable for the result data of a run."""

+ResultValidatorFunc = Union[
+    Callable[[RunContext[AgentDeps], ResultData], ResultData],
+    Callable[[RunContext[AgentDeps], ResultData], Awaitable[ResultData]],
+    Callable[[ResultData], ResultData],
+    Callable[[ResultData], Awaitable[ResultData]],
+]
+"""
+A function that always takes `ResultData` and returns `ResultData` and:
+
+* may or may not take [`RunContext`][pydantic_ai.tools.RunContext] as a first argument
+* may or may not be async
+
+Usage `ResultValidatorFunc[AgentDeps, ResultData]`.
+"""
+
 _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')


 @dataclass
 class Usage:
-    """LLM usage associated
+    """LLM usage associated with a request or run.

     Responsibility for calculating usage is on the model; PydanticAI simply sums the usage information across requests.

@@ -36,7 +54,7 @@ class Usage:
     """

     requests: int = 0
-    """Number of requests made."""
+    """Number of requests made to the LLM API."""
     request_tokens: int | None = None
     """Tokens used in processing requests."""
     response_tokens: int | None = None
@@ -46,25 +64,33 @@ class Usage:
     details: dict[str, int] | None = None
     """Any extra details returned by the model."""

-    def __add__(self, other: Usage) -> Usage:
-        """Add two Usages together.
+    def incr(self, incr_usage: Usage, *, requests: int = 0) -> None:
+        """Increment the usage in place.

-        This is provided so it's trivial to sum usage information from multiple requests and runs.
+        Args:
+            incr_usage: The usage to increment by.
+            requests: The number of requests to increment by in addition to `incr_usage.requests`.
         """
+        self.requests += requests
         for f in 'requests', 'request_tokens', 'response_tokens', 'total_tokens':
             self_value = getattr(self, f)
-            other_value = getattr(other, f)
+            other_value = getattr(incr_usage, f)
             if self_value is not None or other_value is not None:
-                counts[f] = (self_value or 0) + (other_value or 0)
+                setattr(self, f, (self_value or 0) + (other_value or 0))
+
+        if incr_usage.details:
+            self.details = self.details or {}
+            for key, value in incr_usage.details.items():
+                self.details[key] = self.details.get(key, 0) + value

-        details = details or {}
-        for key, value in other.details.items():
-            details[key] = details.get(key, 0) + value
+    def __add__(self, other: Usage) -> Usage:
+        """Add two Usages together.

+        This is provided so it's trivial to sum usage information from multiple requests and runs.
+        """
+        new_usage = copy(self)
+        new_usage.incr(other)
+        return new_usage


 @dataclass
@@ -119,8 +145,6 @@ class RunResult(_BaseRunResult[ResultData]):
 class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
     """Result of a streamed run that returns structured data via a tool call."""

-    usage_so_far: Usage
-    """Usage of the run up until the last request."""
     _usage_limits: UsageLimits | None
     _stream_response: models.EitherStreamedResponse
     _result_schema: _result.ResultSchema[ResultData] | None
@@ -289,7 +313,7 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
         !!! note
             This won't return the full usage until the stream is finished.
         """
-        return self.usage_so_far + self._stream_response.usage()
+        return self._run_ctx.usage + self._stream_response.usage()

     def timestamp(self) -> datetime:
         """Get the timestamp of the response."""
```
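Splitting accumulation into the in-place `incr` (applied once per model request) and a pure `__add__` keeps `+` side-effect free. A quick check of the arithmetic as implemented above:

```python
from pydantic_ai.result import Usage

run_usage = Usage(requests=1, request_tokens=50, response_tokens=10, total_tokens=60)
request_usage = Usage(request_tokens=30, response_tokens=5, total_tokens=35)

# in place, counting one extra request on top of request_usage.requests
run_usage.incr(request_usage, requests=1)
print(run_usage.requests, run_usage.total_tokens)  # 2 95

# pure addition leaves both operands untouched
combined = run_usage + request_usage
print(combined.total_tokens, run_usage.total_tokens)  # 130 95
```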
pydantic_ai/settings.py
CHANGED

```diff
@@ -22,6 +22,7 @@ class ModelSettings(TypedDict, total=False):
     """The maximum number of tokens to generate before stopping.

     Supported by:
+
     * Gemini
     * Anthropic
     * OpenAI
@@ -37,6 +38,7 @@ class ModelSettings(TypedDict, total=False):
     Note that even with `temperature` of `0.0`, the results will not be fully deterministic.

     Supported by:
+
     * Gemini
     * Anthropic
     * OpenAI
@@ -51,6 +53,7 @@ class ModelSettings(TypedDict, total=False):
     You should either alter `temperature` or `top_p`, but not both.

     Supported by:
+
     * Gemini
     * Anthropic
     * OpenAI
@@ -61,6 +64,7 @@ class ModelSettings(TypedDict, total=False):
     """Override the client-level default timeout for a request, in seconds.

     Supported by:
+
     * Gemini
     * Anthropic
     * OpenAI
@@ -132,6 +136,6 @@ class UsageLimits:
                 f'Exceeded the response_tokens_limit of {self.response_tokens_limit} ({response_tokens=})'
             )

-        total_tokens =
+        total_tokens = usage.total_tokens or 0
         if self.total_tokens_limit is not None and total_tokens > self.total_tokens_limit:
             raise UsageLimitExceeded(f'Exceeded the total_tokens_limit of {self.total_tokens_limit} ({total_tokens=})')
```
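The `or 0` guard matters because token counts are `None` until a model reports them, and `None > limit` would raise a `TypeError`. A small demonstration of the check (exception text taken from the source above):

```python
from pydantic_ai import UsageLimitExceeded
from pydantic_ai.result import Usage
from pydantic_ai.settings import UsageLimits

limits = UsageLimits(total_tokens_limit=100)

# a usage with no token counts reported: total_tokens is None, treated as 0
limits.check_tokens(Usage(requests=1))

try:
    limits.check_tokens(Usage(requests=1, total_tokens=150))
except UsageLimitExceeded as exc:
    print(exc)  # Exceeded the total_tokens_limit of 100 (total_tokens=150)
```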
pydantic_ai/tools.py
CHANGED

```diff
@@ -4,7 +4,7 @@ import dataclasses
 import inspect
 from collections.abc import Awaitable
 from dataclasses import dataclass, field
-from typing import Any, Callable, Generic, TypeVar, Union, cast
+from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, cast

 from pydantic import ValidationError
 from pydantic_core import SchemaValidator
@@ -13,10 +13,12 @@ from typing_extensions import Concatenate, ParamSpec, TypeAlias
 from . import _pydantic, _utils, messages as _messages, models
 from .exceptions import ModelRetry, UnexpectedModelBehavior

+if TYPE_CHECKING:
+    from .result import Usage
+
 __all__ = (
     'AgentDeps',
     'RunContext',
-    'ResultValidatorFunc',
     'SystemPromptFunc',
     'ToolFuncContext',
     'ToolFuncPlain',
@@ -38,14 +40,20 @@ class RunContext(Generic[AgentDeps]):

     deps: AgentDeps
     """Dependencies for the agent."""
-    retry: int
-    """Number of retries so far."""
-    messages: list[_messages.ModelMessage]
-    """Messages exchanged in the conversation so far."""
-    tool_name: str | None
-    """Name of the tool being called."""
     model: models.Model
     """The model used in this run."""
+    usage: Usage
+    """LLM usage associated with the run."""
+    prompt: str
+    """The original user prompt passed to the run."""
+    messages: list[_messages.ModelMessage] = field(default_factory=list)
+    """Messages exchanged in the conversation so far."""
+    tool_name: str | None = None
+    """Name of the tool being called."""
+    retry: int = 0
+    """Number of retries so far."""
+    run_step: int = 0
+    """The current step in the run."""

     def replace_with(
         self, retry: int | None = None, tool_name: str | None | _utils.Unset = _utils.UNSET
@@ -73,21 +81,6 @@ SystemPromptFunc = Union[
 Usage `SystemPromptFunc[AgentDeps]`.
 """

-ResultData = TypeVar('ResultData')
-
-ResultValidatorFunc = Union[
-    Callable[[RunContext[AgentDeps], ResultData], ResultData],
-    Callable[[RunContext[AgentDeps], ResultData], Awaitable[ResultData]],
-    Callable[[ResultData], ResultData],
-    Callable[[ResultData], Awaitable[ResultData]],
-]
-"""
-A function that always takes `ResultData` and returns `ResultData`,
-but may or maybe not take `CallInfo` as a first argument, and may or may not be async.
-
-Usage `ResultValidator[AgentDeps, ResultData]`.
-"""
-
 ToolFuncContext = Callable[Concatenate[RunContext[AgentDeps], ToolParams], Any]
 """A tool function that takes `RunContext` as the first argument.
```
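With `messages`, `tool_name`, `retry`, and `run_step` now defaulted, a `RunContext` can be built from just `deps`, `model`, `usage`, and `prompt`, and tools can read run state directly off it. An illustrative tool using the new fields:

```python
from pydantic_ai import Agent, RunContext

agent = Agent('test')


@agent.tool
def report_state(ctx: RunContext[None]) -> str:
    # `prompt`, `run_step` and `usage` are the fields added in this release
    return f'prompt={ctx.prompt!r} run_step={ctx.run_step} requests={ctx.usage.requests}'


print(agent.run_sync('check state').data)
```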
{pydantic_ai_slim-0.0.14.dist-info → pydantic_ai_slim-0.0.16.dist-info}/METADATA
CHANGED

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-ai-slim
-Version: 0.0.14
+Version: 0.0.16
 Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
 Author-email: Samuel Colvin <samuel@pydantic.dev>
 License-Expression: MIT
@@ -36,7 +36,6 @@ Requires-Dist: groq>=0.12.0; extra == 'groq'
 Provides-Extra: logfire
 Requires-Dist: logfire>=2.3; extra == 'logfire'
 Provides-Extra: mistral
-Requires-Dist: json-repair>=0.30.3; extra == 'mistral'
 Requires-Dist: mistralai>=1.2.5; extra == 'mistral'
 Provides-Extra: openai
 Requires-Dist: openai>=1.54.3; extra == 'openai'
```

pydantic_ai_slim-0.0.16.dist-info/RECORD
ADDED

```diff
@@ -0,0 +1,26 @@
+pydantic_ai/__init__.py,sha256=FbYetEgT6OO25u2KF5ZnFxKpz5DtnSpfckRXP4mjl8E,489
+pydantic_ai/_griffe.py,sha256=Wqk3AuyeWuPwE5s1GbMeCsERelx1B4QcU9uYZSoko8s,3409
+pydantic_ai/_pydantic.py,sha256=qXi5IsyiYOHeg_-qozCdxkfeqw2z0gBTjqgywBCiJWo,8125
+pydantic_ai/_result.py,sha256=cUSugZQV0n5Z4fFHiMqua-2xs_0S6m-rr-yd6QS3nFE,10317
+pydantic_ai/_system_prompt.py,sha256=MZJWksIoS5GM3Au5lznlcQnC-h7eqwtE7oI5WFgRcOg,1090
+pydantic_ai/_utils.py,sha256=skWNgm89US_x1EpxdRy5wCkghBrm1XgxFCiEh6wAkAo,8753
+pydantic_ai/agent.py,sha256=NJTcPSlqb4Fd-x9pDPuoXGCwFGF1GHcHevutoB0Busw,52333
+pydantic_ai/exceptions.py,sha256=eGDKX6bGhgVxXBzu81Sk3iiAkXr0GUtgT7bD5Rxlqpg,2028
+pydantic_ai/messages.py,sha256=ImbWY8Ft3mxInUQ08EmIWywf4nJBvTiJhmsECRYDkSQ,8968
+pydantic_ai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+pydantic_ai/result.py,sha256=LbZVHZnJnQwgegSz5PtwS9r_ifrJnLRpsa9xjYnHg1g,15549
+pydantic_ai/settings.py,sha256=W8krcFsujjhE03rwckrz39F4Dz_9RwdBSeEF3izK0-Y,4918
+pydantic_ai/tools.py,sha256=mnh3Lvs0Ri0FkqpV1MUooExNN4epTCcBKw6DyZvNSQ8,11745
+pydantic_ai/models/__init__.py,sha256=XHt02IDQAircb-lEkIbIcuabSAIh5_UKnz2V1xN0Glw,10926
+pydantic_ai/models/anthropic.py,sha256=EUZgmvT0jhMDbooBp_jfW0z2cM5jTMuAhVws1XKgaNs,13451
+pydantic_ai/models/function.py,sha256=i7qkS_31aHrTbYVh6OzQ7Cwucz44F5PjT2EJK3GMphw,10573
+pydantic_ai/models/gemini.py,sha256=Sr19D2hN8iEAcoLlzv5883pto90TgEr_xiGlV8hMOwA,28572
+pydantic_ai/models/groq.py,sha256=ZoPkuWJrf78JPnTRfZhi7v0ETgxJKNN5dH8BLWagGGk,15770
+pydantic_ai/models/mistral.py,sha256=xGVI6-b8-9vnFickPPI2cRaHEWLc0jKKUM_vMjipf-U,25894
+pydantic_ai/models/ollama.py,sha256=ELqxhcNcnvQBnadd3gukS01zprUp6v8N_h1P5K-uf6c,4188
+pydantic_ai/models/openai.py,sha256=qFFInL3NbgfGcsAWigxMP5mscp76hC-jJimHc9woU6Y,16518
+pydantic_ai/models/test.py,sha256=u2pdZd9OLXQ_jI6CaVt96udXuIcv0Hfnfqd3pFGmeJM,16514
+pydantic_ai/models/vertexai.py,sha256=DBCBfpvpIhZaMG7cKvRl5rugCZqJqqEFm74uBc45weo,9259
+pydantic_ai_slim-0.0.16.dist-info/METADATA,sha256=4udd7j2erIuMC0ekYgmgQAqsKfhA5sLsKzTcD_QyOeo,2730
+pydantic_ai_slim-0.0.16.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+pydantic_ai_slim-0.0.16.dist-info/RECORD,,
```

pydantic_ai_slim-0.0.14.dist-info/RECORD
DELETED

```diff
@@ -1,26 +0,0 @@
-pydantic_ai/__init__.py,sha256=a3ffrVF4eJyylreRa6LbTMF6RI7VIv96zHv2tCWo1kQ,439
-pydantic_ai/_griffe.py,sha256=pRjCJ6B1hhx6k46XJgl9zF6aRYxRmqEZKFok8unp4Iw,3449
-pydantic_ai/_pydantic.py,sha256=qXi5IsyiYOHeg_-qozCdxkfeqw2z0gBTjqgywBCiJWo,8125
-pydantic_ai/_result.py,sha256=iL0oZXvuCEoa37EKGXkEhn90oB_950emKG-uXdlmssM,10317
-pydantic_ai/_system_prompt.py,sha256=MZJWksIoS5GM3Au5lznlcQnC-h7eqwtE7oI5WFgRcOg,1090
-pydantic_ai/_utils.py,sha256=skWNgm89US_x1EpxdRy5wCkghBrm1XgxFCiEh6wAkAo,8753
-pydantic_ai/agent.py,sha256=6iOVgmoGw5wGd0W4ewFgOGJqTLmtJLyl5RLqt3dSegE,49574
-pydantic_ai/exceptions.py,sha256=eGDKX6bGhgVxXBzu81Sk3iiAkXr0GUtgT7bD5Rxlqpg,2028
-pydantic_ai/messages.py,sha256=ImbWY8Ft3mxInUQ08EmIWywf4nJBvTiJhmsECRYDkSQ,8968
-pydantic_ai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-pydantic_ai/result.py,sha256=pE8YTCUDS0hXTplEJKSG3UnhMvwGNqaOU-9eC09K4y4,14735
-pydantic_ai/settings.py,sha256=rSNqWYzBoIwC9wZCS2M_aEdCNAy5sOXbi5bXJAuqK08,4923
-pydantic_ai/tools.py,sha256=9ZEhvgylv3pc0_JaahCJgJlFxubCWeEaCNqERl-I9B0,11982
-pydantic_ai/models/__init__.py,sha256=XHt02IDQAircb-lEkIbIcuabSAIh5_UKnz2V1xN0Glw,10926
-pydantic_ai/models/anthropic.py,sha256=EUZgmvT0jhMDbooBp_jfW0z2cM5jTMuAhVws1XKgaNs,13451
-pydantic_ai/models/function.py,sha256=i7qkS_31aHrTbYVh6OzQ7Cwucz44F5PjT2EJK3GMphw,10573
-pydantic_ai/models/gemini.py,sha256=8vdcW4izL9NUGFj6lcD9yIPaakCtsmHauTvKwlTzD14,28207
-pydantic_ai/models/groq.py,sha256=ZoPkuWJrf78JPnTRfZhi7v0ETgxJKNN5dH8BLWagGGk,15770
-pydantic_ai/models/mistral.py,sha256=xZMK2vNLDR4uw1XoAQs-3obeA6c39Q1Qhei8w1JMIow,26458
-pydantic_ai/models/ollama.py,sha256=i3mMXkXu9xL6f4c52Eyx3j4aHKfYoloFondlGHPtkS4,3971
-pydantic_ai/models/openai.py,sha256=qFFInL3NbgfGcsAWigxMP5mscp76hC-jJimHc9woU6Y,16518
-pydantic_ai/models/test.py,sha256=pty5qaudHsSDvdE89HqMj-kmd4UMV9VJI2YGtdfOX1o,15960
-pydantic_ai/models/vertexai.py,sha256=DBCBfpvpIhZaMG7cKvRl5rugCZqJqqEFm74uBc45weo,9259
-pydantic_ai_slim-0.0.14.dist-info/METADATA,sha256=GgTCeoTiFNW1S0pSniLn7ILqGrLK8eaOkOG84aYE06M,2785
-pydantic_ai_slim-0.0.14.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-pydantic_ai_slim-0.0.14.dist-info/RECORD,,
```
{pydantic_ai_slim-0.0.14.dist-info → pydantic_ai_slim-0.0.16.dist-info}/WHEEL
File without changes