pydantic-ai-slim 0.0.15__py3-none-any.whl → 0.0.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/_griffe.py +1 -2
- pydantic_ai/agent.py +71 -60
- pydantic_ai/models/gemini.py +11 -4
- pydantic_ai/models/ollama.py +4 -1
- pydantic_ai/models/test.py +18 -7
- pydantic_ai/result.py +22 -15
- pydantic_ai/settings.py +1 -1
- pydantic_ai/tools.py +16 -7
- {pydantic_ai_slim-0.0.15.dist-info → pydantic_ai_slim-0.0.16.dist-info}/METADATA +1 -1
- {pydantic_ai_slim-0.0.15.dist-info → pydantic_ai_slim-0.0.16.dist-info}/RECORD +11 -11
- {pydantic_ai_slim-0.0.15.dist-info → pydantic_ai_slim-0.0.16.dist-info}/WHEEL +0 -0
pydantic_ai/_griffe.py
CHANGED

@@ -4,8 +4,7 @@ import re
 from inspect import Signature
 from typing import Any, Callable, Literal, cast
 
-from _griffe.enumerations import DocstringSectionKind
-from _griffe.models import Docstring, Object as GriffeObject
+from griffe import Docstring, DocstringSectionKind, Object as GriffeObject
 
 DocstringStyle = Literal['google', 'numpy', 'sphinx']
 
pydantic_ai/agent.py
CHANGED

@@ -6,7 +6,6 @@ import inspect
 from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
 from contextlib import asynccontextmanager, contextmanager
 from contextvars import ContextVar
-from dataclasses import dataclass, field
 from types import FrameType
 from typing import Any, Callable, Generic, Literal, cast, final, overload
 
@@ -40,6 +39,16 @@ __all__ = 'Agent', 'capture_run_messages', 'EndStrategy'
 
 _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')
 
+# while waiting for https://github.com/pydantic/logfire/issues/745
+try:
+    import logfire._internal.stack_info
+except ImportError:
+    pass
+else:
+    from pathlib import Path
+
+    logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)
+
 NoneType = type(None)
 EndStrategy = Literal['early', 'exhaustive']
 """The strategy for handling multiple tool calls when a final result is found.
@@ -50,7 +59,7 @@ EndStrategy = Literal['early', 'exhaustive']
 
 
 @final
-@dataclass(init=False)
+@dataclasses.dataclass(init=False)
 class Agent(Generic[AgentDeps, ResultData]):
     """Class for defining "agents" - a way to have a specific type of "conversation" with an LLM.
 
@@ -90,17 +99,17 @@ class Agent(Generic[AgentDeps, ResultData]):
         be merged with this value, with the runtime argument taking priority.
     """
 
-    _result_schema: _result.ResultSchema[ResultData] | None = field(repr=False)
-    _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = field(repr=False)
-    _allow_text_result: bool = field(repr=False)
-    _system_prompts: tuple[str, ...] = field(repr=False)
-    _function_tools: dict[str, Tool[AgentDeps]] = field(repr=False)
-    _default_retries: int = field(repr=False)
-    _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = field(repr=False)
-    _deps_type: type[AgentDeps] = field(repr=False)
-    _max_result_retries: int = field(repr=False)
-    _override_deps: _utils.Option[AgentDeps] = field(default=None, repr=False)
-    _override_model: _utils.Option[models.Model] = field(default=None, repr=False)
+    _result_schema: _result.ResultSchema[ResultData] | None = dataclasses.field(repr=False)
+    _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = dataclasses.field(repr=False)
+    _allow_text_result: bool = dataclasses.field(repr=False)
+    _system_prompts: tuple[str, ...] = dataclasses.field(repr=False)
+    _function_tools: dict[str, Tool[AgentDeps]] = dataclasses.field(repr=False)
+    _default_retries: int = dataclasses.field(repr=False)
+    _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = dataclasses.field(repr=False)
+    _deps_type: type[AgentDeps] = dataclasses.field(repr=False)
+    _max_result_retries: int = dataclasses.field(repr=False)
+    _override_deps: _utils.Option[AgentDeps] = dataclasses.field(default=None, repr=False)
+    _override_model: _utils.Option[models.Model] = dataclasses.field(default=None, repr=False)
 
     def __init__(
         self,
@@ -184,6 +193,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         deps: AgentDeps = None,
         model_settings: ModelSettings | None = None,
         usage_limits: UsageLimits | None = None,
+        usage: result.Usage | None = None,
         infer_name: bool = True,
     ) -> result.RunResult[ResultData]:
         """Run the agent with a user prompt in async mode.
@@ -206,6 +216,7 @@ class Agent(Generic[AgentDeps, ResultData]):
             deps: Optional dependencies to use for this run.
             model_settings: Optional settings to use for this model's request.
             usage_limits: Optional limits on model request count or token usage.
+            usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
             infer_name: Whether to try to infer the agent name from the call frame if it's not set.
 
         Returns:
@@ -213,7 +224,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         """
         if infer_name and self.name is None:
             self._infer_name(inspect.currentframe())
-        model_used, mode_selection = await self._get_model(model)
+        model_used = await self._get_model(model)
 
         deps = self._get_deps(deps)
         new_message_index = len(message_history) if message_history else 0
@@ -222,40 +233,36 @@ class Agent(Generic[AgentDeps, ResultData]):
             '{agent_name} run {prompt=}',
             prompt=user_prompt,
             agent=self,
-            mode_selection=mode_selection,
             model_name=model_used.name(),
             agent_name=self.name or 'agent',
         ) as run_span:
-            run_context = RunContext(deps, 0, [], None, model_used)
+            run_context = RunContext(deps, model_used, usage or result.Usage(), user_prompt)
            messages = await self._prepare_messages(user_prompt, message_history, run_context)
            run_context.messages = messages
 
            for tool in self._function_tools.values():
                tool.current_retry = 0
 
-            usage = result.Usage(requests=0)
            model_settings = merge_model_settings(self.model_settings, model_settings)
            usage_limits = usage_limits or UsageLimits()
 
-            run_step = 0
            while True:
-                usage_limits.check_before_request(usage)
+                usage_limits.check_before_request(run_context.usage)
 
-                run_step += 1
-                with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
+                run_context.run_step += 1
+                with _logfire.span('preparing model and tools {run_step=}', run_step=run_context.run_step):
                    agent_model = await self._prepare_model(run_context)
 
-                with _logfire.span('model request', run_step=run_step) as model_req_span:
+                with _logfire.span('model request', run_step=run_context.run_step) as model_req_span:
                    model_response, request_usage = await agent_model.request(messages, model_settings)
                    model_req_span.set_attribute('response', model_response)
                    model_req_span.set_attribute('usage', request_usage)
 
                    messages.append(model_response)
-                    usage += request_usage
-                    usage.requests += 1
-                    usage_limits.check_tokens(request_usage)
+                    run_context.usage.incr(request_usage, requests=1)
+                    usage_limits.check_tokens(run_context.usage)
 
-                with _logfire.span('handle model response', run_step=run_step) as handle_span:
+                with _logfire.span('handle model response', run_step=run_context.run_step) as handle_span:
                    final_result, tool_responses = await self._handle_model_response(model_response, run_context)
 
                    if tool_responses:
@@ -266,10 +273,10 @@ class Agent(Generic[AgentDeps, ResultData]):
                    if final_result is not None:
                        result_data = final_result.data
                        run_span.set_attribute('all_messages', messages)
-                        run_span.set_attribute('usage', usage)
+                        run_span.set_attribute('usage', run_context.usage)
                        handle_span.set_attribute('result', result_data)
                        handle_span.message = 'handle model response -> final result'
-                        return result.RunResult(messages, new_message_index, result_data, usage)
+                        return result.RunResult(messages, new_message_index, result_data, run_context.usage)
                    else:
                        # continue the conversation
                        handle_span.set_attribute('tool_responses', tool_responses)
@@ -285,6 +292,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         deps: AgentDeps = None,
         model_settings: ModelSettings | None = None,
         usage_limits: UsageLimits | None = None,
+        usage: result.Usage | None = None,
         infer_name: bool = True,
     ) -> result.RunResult[ResultData]:
         """Run the agent with a user prompt synchronously.
@@ -311,6 +319,7 @@ class Agent(Generic[AgentDeps, ResultData]):
             deps: Optional dependencies to use for this run.
             model_settings: Optional settings to use for this model's request.
             usage_limits: Optional limits on model request count or token usage.
+            usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
             infer_name: Whether to try to infer the agent name from the call frame if it's not set.
 
         Returns:
@@ -326,6 +335,7 @@ class Agent(Generic[AgentDeps, ResultData]):
                 deps=deps,
                 model_settings=model_settings,
                 usage_limits=usage_limits,
+                usage=usage,
                 infer_name=False,
             )
         )
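
The new `usage` parameter on `run` and `run_sync` (and on `run_stream` below) seeds a run with previously accumulated usage, so token and request counts carry across a resumed conversation or an agent invoked from inside another agent's tool. A minimal sketch of the pattern; the model name and prompts are illustrative, and it assumes the existing `RunResult.usage()` and `all_messages()` accessors:

    from pydantic_ai import Agent

    agent = Agent('openai:gpt-4o')  # illustrative model choice

    first = agent.run_sync('What is the capital of France?')
    # seed the follow-up run with the usage accumulated so far, so both
    # runs are counted together in the final totals
    second = agent.run_sync(
        'How many people live there?',
        message_history=first.all_messages(),
        usage=first.usage(),
    )
    print(second.usage())  # requests/tokens cover both runs
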
@@ -340,6 +350,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         deps: AgentDeps = None,
         model_settings: ModelSettings | None = None,
         usage_limits: UsageLimits | None = None,
+        usage: result.Usage | None = None,
         infer_name: bool = True,
     ) -> AsyncIterator[result.StreamedRunResult[AgentDeps, ResultData]]:
         """Run the agent with a user prompt in async mode, returning a streamed response.
@@ -363,6 +374,7 @@ class Agent(Generic[AgentDeps, ResultData]):
             deps: Optional dependencies to use for this run.
             model_settings: Optional settings to use for this model's request.
             usage_limits: Optional limits on model request count or token usage.
+            usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
             infer_name: Whether to try to infer the agent name from the call frame if it's not set.
 
         Returns:
@@ -372,7 +384,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         # f_back because `asynccontextmanager` adds one frame
         if frame := inspect.currentframe():  # pragma: no branch
             self._infer_name(frame.f_back)
-        model_used, mode_selection = await self._get_model(model)
+        model_used = await self._get_model(model)
 
         deps = self._get_deps(deps)
         new_message_index = len(message_history) if message_history else 0
@@ -381,32 +393,29 @@ class Agent(Generic[AgentDeps, ResultData]):
             '{agent_name} run stream {prompt=}',
             prompt=user_prompt,
             agent=self,
-            mode_selection=mode_selection,
             model_name=model_used.name(),
             agent_name=self.name or 'agent',
         ) as run_span:
-            run_context = RunContext(deps, 0, [], None, model_used)
+            run_context = RunContext(deps, model_used, usage or result.Usage(), user_prompt)
            messages = await self._prepare_messages(user_prompt, message_history, run_context)
            run_context.messages = messages
 
            for tool in self._function_tools.values():
                tool.current_retry = 0
 
-            usage = result.Usage()
            model_settings = merge_model_settings(self.model_settings, model_settings)
            usage_limits = usage_limits or UsageLimits()
 
-            run_step = 0
            while True:
-                run_step += 1
-                usage_limits.check_before_request(usage)
+                run_context.run_step += 1
+                usage_limits.check_before_request(run_context.usage)
 
-                with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
+                with _logfire.span('preparing model and tools {run_step=}', run_step=run_context.run_step):
                    agent_model = await self._prepare_model(run_context)
 
-                with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span:
+                with _logfire.span('model request {run_step=}', run_step=run_context.run_step) as model_req_span:
                    async with agent_model.request_stream(messages, model_settings) as model_response:
-                        usage.requests += 1
+                        run_context.usage.requests += 1
                        model_req_span.set_attribute('response_type', model_response.__class__.__name__)
                        # We want to end the "model request" span here, but we can't exit the context manager
                        # in the traditional way
@@ -442,7 +451,6 @@ class Agent(Generic[AgentDeps, ResultData]):
                        yield result.StreamedRunResult(
                            messages,
                            new_message_index,
-                            usage,
                            usage_limits,
                            result_stream,
                            self._result_schema,
@@ -466,8 +474,8 @@ class Agent(Generic[AgentDeps, ResultData]):
                        handle_span.message = f'handle model response -> {tool_responses_str}'
                        # the model_response should have been fully streamed by now, we can add its usage
                        model_response_usage = model_response.usage()
-                        usage += model_response_usage
-                        usage_limits.check_tokens(usage)
+                        run_context.usage.incr(model_response_usage)
+                        usage_limits.check_tokens(run_context.usage)
 
     @contextmanager
     def override(
@@ -778,14 +786,14 @@ class Agent(Generic[AgentDeps, ResultData]):
 
         self._function_tools[tool.name] = tool
 
-    async def _get_model(self, model: models.Model | models.KnownModelName | None) -> tuple[models.Model, str]:
+    async def _get_model(self, model: models.Model | models.KnownModelName | None) -> models.Model:
         """Create a model configured for this agent.
 
         Args:
             model: model to use for this run, required if `model` was not set when creating the agent.
 
         Returns:
-
+            The model used
         """
         model_: models.Model
         if some_model := self._override_model:
@@ -796,18 +804,15 @@ class Agent(Generic[AgentDeps, ResultData]):
                     '(Even when `override(model=...)` is customizing the model that will actually be called)'
                 )
             model_ = some_model.value
-            mode_selection = 'override-model'
         elif model is not None:
             model_ = models.infer_model(model)
-            mode_selection = 'custom'
         elif self.model is not None:
             # noinspection PyTypeChecker
             model_ = self.model = models.infer_model(self.model)
-            mode_selection = 'from-agent'
         else:
             raise exceptions.UserError('`model` must be set either when creating the agent or when calling it.')
 
-        return model_, mode_selection
+        return model_
 
     async def _prepare_model(self, run_context: RunContext[AgentDeps]) -> models.AgentModel:
         """Build tools and create an agent model."""
@@ -830,15 +835,15 @@ class Agent(Generic[AgentDeps, ResultData]):
         self, user_prompt: str, message_history: list[_messages.ModelMessage] | None, run_context: RunContext[AgentDeps]
     ) -> list[_messages.ModelMessage]:
         try:
-            messages = _messages_ctx_var.get()
+            ctx_messages = _messages_ctx_var.get()
         except LookupError:
-            messages = []
+            messages: list[_messages.ModelMessage] = []
         else:
-            if messages:
-
-
-
-
+            if ctx_messages.used:
+                messages = []
+            else:
+                messages = ctx_messages.messages
+                ctx_messages.used = True
 
         if message_history:
             # shallow copy messages
@@ -1132,7 +1137,13 @@ class Agent(Generic[AgentDeps, ResultData]):
         raise AttributeError('The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.')
 
 
-_messages_ctx_var: ContextVar[list[_messages.ModelMessage]] = ContextVar('var')
+@dataclasses.dataclass
+class _RunMessages:
+    messages: list[_messages.ModelMessage]
+    used: bool = False
+
+
+_messages_ctx_var: ContextVar[_RunMessages] = ContextVar('var')
 
 
 @contextmanager
@@ -1156,21 +1167,21 @@ def capture_run_messages() -> Iterator[list[_messages.ModelMessage]]:
     ```
 
     !!! note
-
-
+        If you call `run`, `run_sync`, or `run_stream` more than once within a single `capture_run_messages` context,
+        `messages` will represent the messages exchanged during the first call only.
     """
     try:
-        yield _messages_ctx_var.get()
+        yield _messages_ctx_var.get().messages
     except LookupError:
         messages: list[_messages.ModelMessage] = []
-        token = _messages_ctx_var.set(messages)
+        token = _messages_ctx_var.set(_RunMessages(messages))
         try:
             yield messages
         finally:
             _messages_ctx_var.reset(token)
 
 
-@dataclass
+@dataclasses.dataclass
 class _MarkFinalResult(Generic[ResultData]):
     """Marker class to indicate that the result is the final result.
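
`capture_run_messages` now stores a `_RunMessages` wrapper with a `used` flag in the context variable rather than a bare list, so only the first run inside the context writes into the captured list; later runs get a fresh, uncaptured list. A short sketch of the documented behavior (model name illustrative):

    from pydantic_ai import Agent, capture_run_messages

    agent = Agent('test')  # illustrative; the built-in test model

    with capture_run_messages() as messages:
        agent.run_sync('first question')
        agent.run_sync('second question')

    # per the new docstring note, `messages` holds only the messages
    # exchanged during the first call
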
pydantic_ai/models/gemini.py
CHANGED

@@ -444,7 +444,8 @@ def _content_model_response(m: ModelResponse) -> _GeminiContent:
         if isinstance(item, ToolCallPart):
             parts.append(_function_call_part_from_call(item))
         elif isinstance(item, TextPart):
-            parts.append(_GeminiTextPart(text=item.content))
+            if item.content:
+                parts.append(_GeminiTextPart(text=item.content))
         else:
             assert_never(item)
     return _GeminiContent(role='model', parts=parts)
@@ -701,7 +702,7 @@ class _GeminiJsonSchema:
 
     def _simplify(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
         schema.pop('title', None)
-        schema.pop('default', None)
+        default = schema.pop('default', _utils.UNSET)
         if ref := schema.pop('$ref', None):
             # noinspection PyTypeChecker
             key = re.sub(r'^#/\$defs/', '', ref)
@@ -714,8 +715,14 @@ class _GeminiJsonSchema:
             return
 
         if any_of := schema.get('anyOf'):
-            for item_schema in any_of:
-                self._simplify(item_schema, refs_stack)
+            for item_schema in any_of:
+                self._simplify(item_schema, refs_stack)
+            if len(any_of) == 2 and {'type': 'null'} in any_of and default is None:
+                for item_schema in any_of:
+                    if item_schema != {'type': 'null'}:
+                        schema.clear()
+                        schema.update(item_schema)
+                        return
 
         type_ = schema.get('type')
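
The `_simplify` change unwraps the schema pydantic emits for `x: SomeType | None = None`: when `anyOf` is exactly a pair of `X` and `{'type': 'null'}` and the popped `default` is `None`, the wrapper collapses to plain `X`, which Gemini's function-calling schema dialect can accept. A standalone sketch of the rule, not the library's internal API:

    from typing import Any

    def simplify_nullable_any_of(schema: dict[str, Any], default: Any) -> dict[str, Any]:
        """Collapse anyOf [X, {'type': 'null'}] with a None default to just X."""
        any_of = schema.get('anyOf')
        if any_of and len(any_of) == 2 and {'type': 'null'} in any_of and default is None:
            return next(s for s in any_of if s != {'type': 'null'})
        return schema

    # the JSON schema pydantic produces for `x: int | None = None`:
    print(simplify_nullable_any_of({'anyOf': [{'type': 'integer'}, {'type': 'null'}]}, None))
    # -> {'type': 'integer'}
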
pydantic_ai/models/ollama.py
CHANGED

@@ -71,6 +71,7 @@ class OllamaModel(Model):
         model_name: OllamaModelName,
         *,
         base_url: str | None = 'http://localhost:11434/v1/',
+        api_key: str = 'ollama',
         openai_client: AsyncOpenAI | None = None,
         http_client: AsyncHTTPClient | None = None,
     ):
@@ -83,6 +84,8 @@ class OllamaModel(Model):
             model_name: The name of the Ollama model to use. List of models available [here](https://ollama.com/library)
                 You must first download the model (`ollama pull <MODEL-NAME>`) in order to use the model
             base_url: The base url for the ollama requests. The default value is the ollama default
+            api_key: The API key to use for authentication. Defaults to 'ollama' for local instances,
+                but can be customized for proxy setups that require authentication
             openai_client: An existing
                 [`AsyncOpenAI`](https://github.com/openai/openai-python?tab=readme-ov-file#async-usage)
                 client to use, if provided, `base_url` and `http_client` must be `None`.
@@ -96,7 +99,7 @@ class OllamaModel(Model):
         else:
             # API key is not required for ollama but a value is required to create the client
             http_client_ = http_client or cached_async_http_client()
-            oai_client = AsyncOpenAI(base_url=base_url, api_key='ollama', http_client=http_client_)
+            oai_client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client_)
         self.openai_model = OpenAIModel(model_name=model_name, openai_client=oai_client)
 
     async def agent_model(
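
Previously the `AsyncOpenAI` client was always built with the dummy key `'ollama'`; the new `api_key` parameter lets you pass a real credential when the Ollama server sits behind an authenticating proxy. For example, with placeholder URL and token:

    from pydantic_ai.models.ollama import OllamaModel

    model = OllamaModel(
        'llama3.2',  # any model you have pulled with `ollama pull`
        base_url='https://ollama.example.com/v1/',  # placeholder proxy URL
        api_key='my-proxy-token',  # placeholder credential
    )
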
pydantic_ai/models/test.py
CHANGED

@@ -16,6 +16,7 @@ from ..messages import (
     ModelMessage,
     ModelRequest,
     ModelResponse,
+    ModelResponsePart,
     RetryPromptPart,
     TextPart,
     ToolCallPart,
@@ -177,13 +178,23 @@ class TestAgentModel(AgentModel):
         # check if there are any retry prompts, if so retry them
         new_retry_names = {p.tool_name for p in last_message.parts if isinstance(p, RetryPromptPart)}
         if new_retry_names:
-            return ModelResponse(
-                parts=[
-                    ToolCallPart.from_raw_args(name, self.gen_tool_args(args))
-                    for name, args in self.tool_calls
-                    if name in new_retry_names
-                ]
-            )
+            # Handle retries for both function tools and result tools
+            # Check function tools first
+            retry_parts: list[ModelResponsePart] = [
+                ToolCallPart.from_raw_args(name, self.gen_tool_args(args))
+                for name, args in self.tool_calls
+                if name in new_retry_names
+            ]
+            # Check result tools
+            if self.result_tools:
+                retry_parts.extend(
+                    [
+                        ToolCallPart.from_raw_args(tool.name, self.gen_tool_args(tool))
+                        for tool in self.result_tools
+                        if tool.name in new_retry_names
+                    ]
+                )
+            return ModelResponse(parts=retry_parts)
 
         if response_text := self.result.left:
             if response_text.value is None:
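
Before this change, a retry prompt naming a result tool matched nothing in `self.tool_calls`, so `TestAgentModel` could answer a retry with an empty parts list; now result tools are re-called too. A hedged sketch of the case it fixes, using a result validator that rejects the first structured answer (names illustrative):

    from pydantic_ai import Agent, ModelRetry, RunContext
    from pydantic_ai.models.test import TestModel

    agent = Agent(TestModel(), result_type=int)

    @agent.result_validator
    async def retry_once(ctx: RunContext[None], value: int) -> int:
        if ctx.retry == 0:
            raise ModelRetry('try again')  # forces a retry of the result tool
        return value

    result = agent.run_sync('anything')  # succeeds on the re-called result tool
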
pydantic_ai/result.py
CHANGED

@@ -2,6 +2,7 @@ from __future__ import annotations as _annotations
 
 from abc import ABC, abstractmethod
 from collections.abc import AsyncIterator, Awaitable, Callable
+from copy import copy
 from dataclasses import dataclass, field
 from datetime import datetime
 from typing import Generic, Union, cast
@@ -63,25 +64,33 @@ class Usage:
     details: dict[str, int] | None = None
     """Any extra details returned by the model."""
 
-    def __add__(self, other: Usage) -> Usage:
-        """Add two Usages together.
+    def incr(self, incr_usage: Usage, *, requests: int = 0) -> None:
+        """Increment the usage in place.
 
-        This is provided so it's trivial to sum usage information from multiple requests and runs.
+        Args:
+            incr_usage: The usage to increment by.
+            requests: The number of requests to increment by in addition to `incr_usage.requests`.
         """
-        counts: dict[str, int] = {}
+        self.requests += requests
         for f in 'requests', 'request_tokens', 'response_tokens', 'total_tokens':
             self_value = getattr(self, f)
-            other_value = getattr(other, f)
+            other_value = getattr(incr_usage, f)
             if self_value is not None or other_value is not None:
-                counts[f] = (self_value or 0) + (other_value or 0)
+                setattr(self, f, (self_value or 0) + (other_value or 0))
 
-
-
-
-
-                details[key] = details.get(key, 0) + value
+        if incr_usage.details:
+            self.details = self.details or {}
+            for key, value in incr_usage.details.items():
+                self.details[key] = self.details.get(key, 0) + value
 
-
+    def __add__(self, other: Usage) -> Usage:
+        """Add two Usages together.
+
+        This is provided so it's trivial to sum usage information from multiple requests and runs.
+        """
+        new_usage = copy(self)
+        new_usage.incr(other)
+        return new_usage
 
 
 @dataclass
@@ -136,8 +145,6 @@ class RunResult(_BaseRunResult[ResultData]):
 class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
     """Result of a streamed run that returns structured data via a tool call."""
 
-    usage_so_far: Usage
-    """Usage of the run up until the last request."""
     _usage_limits: UsageLimits | None
     _stream_response: models.EitherStreamedResponse
     _result_schema: _result.ResultSchema[ResultData] | None
@@ -306,7 +313,7 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
         !!! note
             This won't return the full usage until the stream is finished.
         """
-        return self.usage_so_far + self._stream_response.usage()
+        return self._run_ctx.usage + self._stream_response.usage()
 
     def timestamp(self) -> datetime:
         """Get the timestamp of the response."""
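
`incr` mutates a `Usage` in place (the run loop calls it with `requests=1` per model request), while `__add__` copies first so summing stays non-destructive. Roughly:

    from pydantic_ai.result import Usage

    a = Usage(requests=1, request_tokens=100, response_tokens=20, total_tokens=120)
    b = Usage(requests=1, request_tokens=50, response_tokens=10, total_tokens=60)

    total = a + b              # __add__: copy of `a`, then incr(b); `a` unchanged
    print(total.total_tokens)  # 180

    a.incr(b, requests=1)      # in place: adds b's counts plus one extra request
    print(a.requests)          # 3 == 1 + b.requests + the extra 1
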
pydantic_ai/settings.py
CHANGED

@@ -136,6 +136,6 @@ class UsageLimits:
                 f'Exceeded the response_tokens_limit of {self.response_tokens_limit} ({response_tokens=})'
             )
 
-        total_tokens = usage.total_tokens
+        total_tokens = usage.total_tokens or 0
         if self.total_tokens_limit is not None and total_tokens > self.total_tokens_limit:
             raise UsageLimitExceeded(f'Exceeded the total_tokens_limit of {self.total_tokens_limit} ({total_tokens=})')
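
The `or 0` guard matters because `Usage.total_tokens` stays `None` when a model reports no totals, and comparing `None` against the limit would raise `TypeError`; with the fix such models simply never trip the limit. The check in context (model name illustrative):

    from pydantic_ai import Agent
    from pydantic_ai.settings import UsageLimits

    agent = Agent('test')  # illustrative model
    # raises UsageLimitExceeded once cumulative total_tokens passes 1000
    result = agent.run_sync('hello', usage_limits=UsageLimits(total_tokens_limit=1000))
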
pydantic_ai/tools.py
CHANGED

@@ -4,7 +4,7 @@ import dataclasses
 import inspect
 from collections.abc import Awaitable
 from dataclasses import dataclass, field
-from typing import Any, Callable, Generic, TypeVar, Union, cast
+from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, cast
 
 from pydantic import ValidationError
 from pydantic_core import SchemaValidator
@@ -13,6 +13,9 @@ from typing_extensions import Concatenate, ParamSpec, TypeAlias
 from . import _pydantic, _utils, messages as _messages, models
 from .exceptions import ModelRetry, UnexpectedModelBehavior
 
+if TYPE_CHECKING:
+    from .result import Usage
+
 __all__ = (
     'AgentDeps',
     'RunContext',
@@ -37,14 +40,20 @@ class RunContext(Generic[AgentDeps]):
 
     deps: AgentDeps
     """Dependencies for the agent."""
-    retry: int
-    """Number of retries so far."""
-    messages: list[_messages.ModelMessage]
-    """Messages exchanged in the conversation so far."""
-    tool_name: str | None
-    """Name of the tool being called."""
     model: models.Model
     """The model used in this run."""
+    usage: Usage
+    """LLM usage associated with the run."""
+    prompt: str
+    """The original user prompt passed to the run."""
+    messages: list[_messages.ModelMessage] = field(default_factory=list)
+    """Messages exchanged in the conversation so far."""
+    tool_name: str | None = None
+    """Name of the tool being called."""
+    retry: int = 0
+    """Number of retries so far."""
+    run_step: int = 0
+    """The current step in the run."""
 
     def replace_with(
         self, retry: int | None = None, tool_name: str | None | _utils.Unset = _utils.UNSET
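
With `RunContext` now carrying `usage`, `prompt`, and `run_step` (and constructed as `RunContext(deps, model, usage, prompt)` with the remaining fields defaulted), tools can observe the live state of a run. An illustrative tool:

    from pydantic_ai import Agent, RunContext

    agent = Agent('test')  # illustrative model

    @agent.tool
    async def report_progress(ctx: RunContext[None]) -> str:
        # ctx.usage accumulates as the run proceeds; ctx.run_step counts
        # model round-trips; ctx.prompt is the original user prompt
        return f'step {ctx.run_step}, {ctx.usage.requests} requests, prompt={ctx.prompt!r}'
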
{pydantic_ai_slim-0.0.15.dist-info → pydantic_ai_slim-0.0.16.dist-info}/RECORD
CHANGED

@@ -1,26 +1,26 @@
 pydantic_ai/__init__.py,sha256=FbYetEgT6OO25u2KF5ZnFxKpz5DtnSpfckRXP4mjl8E,489
-pydantic_ai/_griffe.py,sha256=
+pydantic_ai/_griffe.py,sha256=Wqk3AuyeWuPwE5s1GbMeCsERelx1B4QcU9uYZSoko8s,3409
 pydantic_ai/_pydantic.py,sha256=qXi5IsyiYOHeg_-qozCdxkfeqw2z0gBTjqgywBCiJWo,8125
 pydantic_ai/_result.py,sha256=cUSugZQV0n5Z4fFHiMqua-2xs_0S6m-rr-yd6QS3nFE,10317
 pydantic_ai/_system_prompt.py,sha256=MZJWksIoS5GM3Au5lznlcQnC-h7eqwtE7oI5WFgRcOg,1090
 pydantic_ai/_utils.py,sha256=skWNgm89US_x1EpxdRy5wCkghBrm1XgxFCiEh6wAkAo,8753
-pydantic_ai/agent.py,sha256=
+pydantic_ai/agent.py,sha256=NJTcPSlqb4Fd-x9pDPuoXGCwFGF1GHcHevutoB0Busw,52333
 pydantic_ai/exceptions.py,sha256=eGDKX6bGhgVxXBzu81Sk3iiAkXr0GUtgT7bD5Rxlqpg,2028
 pydantic_ai/messages.py,sha256=ImbWY8Ft3mxInUQ08EmIWywf4nJBvTiJhmsECRYDkSQ,8968
 pydantic_ai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-pydantic_ai/result.py,sha256=
-pydantic_ai/settings.py,sha256=
-pydantic_ai/tools.py,sha256=
+pydantic_ai/result.py,sha256=LbZVHZnJnQwgegSz5PtwS9r_ifrJnLRpsa9xjYnHg1g,15549
+pydantic_ai/settings.py,sha256=W8krcFsujjhE03rwckrz39F4Dz_9RwdBSeEF3izK0-Y,4918
+pydantic_ai/tools.py,sha256=mnh3Lvs0Ri0FkqpV1MUooExNN4epTCcBKw6DyZvNSQ8,11745
 pydantic_ai/models/__init__.py,sha256=XHt02IDQAircb-lEkIbIcuabSAIh5_UKnz2V1xN0Glw,10926
 pydantic_ai/models/anthropic.py,sha256=EUZgmvT0jhMDbooBp_jfW0z2cM5jTMuAhVws1XKgaNs,13451
 pydantic_ai/models/function.py,sha256=i7qkS_31aHrTbYVh6OzQ7Cwucz44F5PjT2EJK3GMphw,10573
-pydantic_ai/models/gemini.py,sha256=
+pydantic_ai/models/gemini.py,sha256=Sr19D2hN8iEAcoLlzv5883pto90TgEr_xiGlV8hMOwA,28572
 pydantic_ai/models/groq.py,sha256=ZoPkuWJrf78JPnTRfZhi7v0ETgxJKNN5dH8BLWagGGk,15770
 pydantic_ai/models/mistral.py,sha256=xGVI6-b8-9vnFickPPI2cRaHEWLc0jKKUM_vMjipf-U,25894
-pydantic_ai/models/ollama.py,sha256=
+pydantic_ai/models/ollama.py,sha256=ELqxhcNcnvQBnadd3gukS01zprUp6v8N_h1P5K-uf6c,4188
 pydantic_ai/models/openai.py,sha256=qFFInL3NbgfGcsAWigxMP5mscp76hC-jJimHc9woU6Y,16518
-pydantic_ai/models/test.py,sha256=
+pydantic_ai/models/test.py,sha256=u2pdZd9OLXQ_jI6CaVt96udXuIcv0Hfnfqd3pFGmeJM,16514
 pydantic_ai/models/vertexai.py,sha256=DBCBfpvpIhZaMG7cKvRl5rugCZqJqqEFm74uBc45weo,9259
-pydantic_ai_slim-0.0.15.dist-info/METADATA,sha256=
-pydantic_ai_slim-0.0.15.dist-info/WHEEL,sha256=
-pydantic_ai_slim-0.0.15.dist-info/RECORD,,
+pydantic_ai_slim-0.0.16.dist-info/METADATA,sha256=4udd7j2erIuMC0ekYgmgQAqsKfhA5sLsKzTcD_QyOeo,2730
+pydantic_ai_slim-0.0.16.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+pydantic_ai_slim-0.0.16.dist-info/RECORD,,

{pydantic_ai_slim-0.0.15.dist-info → pydantic_ai_slim-0.0.16.dist-info}/WHEEL
File without changes