pydantic-ai-slim 0.0.30__tar.gz → 0.0.32__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of pydantic-ai-slim might be problematic.
- {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/PKG-INFO +4 -4
- {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/__init__.py +2 -2
- {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/_agent_graph.py +86 -73
- {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/_result.py +3 -3
- {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/_utils.py +2 -0
- {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/agent.py +54 -47
- {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/messages.py +55 -0
- {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/models/__init__.py +4 -0
- {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/models/anthropic.py +3 -1
- {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/models/gemini.py +1 -0
- {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/models/instrumented.py +72 -101
- {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/result.py +27 -31
- {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pyproject.toml +4 -4
- {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/.gitignore +0 -0
- {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/README.md +0 -0
- {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/_griffe.py +0 -0
- {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/_parts_manager.py +0 -0
- {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/_pydantic.py +0 -0
- {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/_system_prompt.py +0 -0
- {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/common_tools/__init__.py +0 -0
- {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/common_tools/duckduckgo.py +0 -0
- {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/common_tools/tavily.py +0 -0
- {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/exceptions.py +0 -0
- {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/format_as_xml.py +0 -0
- {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/models/cohere.py +0 -0
- {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/models/fallback.py +0 -0
- {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/models/function.py +0 -0
- {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/models/groq.py +0 -0
- {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/models/mistral.py +0 -0
- {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/models/openai.py +0 -0
- {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/models/test.py +0 -0
- {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/models/vertexai.py +0 -0
- {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/models/wrapper.py +0 -0
- {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/py.typed +0 -0
- {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/settings.py +0 -0
- {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/tools.py +0 -0
- {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/usage.py +0 -0
{pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-ai-slim
-Version: 0.0.30
+Version: 0.0.32
 Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
 Author-email: Samuel Colvin <samuel@pydantic.dev>
 License-Expression: MIT
@@ -28,11 +28,11 @@ Requires-Dist: eval-type-backport>=0.2.0
 Requires-Dist: exceptiongroup; python_version < '3.11'
 Requires-Dist: griffe>=1.3.2
 Requires-Dist: httpx>=0.27
-Requires-Dist: …
-Requires-Dist: pydantic-graph==0.0.30
+Requires-Dist: opentelemetry-api>=1.28.0
+Requires-Dist: pydantic-graph==0.0.32
 Requires-Dist: pydantic>=2.10
 Provides-Extra: anthropic
-Requires-Dist: anthropic>=0.…
+Requires-Dist: anthropic>=0.49.0; extra == 'anthropic'
 Provides-Extra: cohere
 Requires-Dist: cohere>=5.13.11; extra == 'cohere'
 Provides-Extra: duckduckgo
{pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/__init__.py

@@ -1,6 +1,6 @@
 from importlib.metadata import version
 
-from .agent import Agent, …
+from .agent import Agent, CallToolsNode, EndStrategy, ModelRequestNode, UserPromptNode, capture_run_messages
 from .exceptions import (
     AgentRunError,
     FallbackExceptionGroup,
@@ -18,7 +18,7 @@ __all__ = (
     # agent
     'Agent',
     'EndStrategy',
-    'HandleResponseNode',
+    'CallToolsNode',
     'ModelRequestNode',
     'UserPromptNode',
     'capture_run_messages',
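This release renames the response-handling graph node from HandleResponseNode to CallToolsNode and re-exports it from the package root alongside ModelRequestNode and UserPromptNode. A minimal sketch of how the new public names might be used to label graph nodes; the Agent definition and the describe helper below are illustrative only, not part of the package:

```python
from pydantic_ai import Agent, CallToolsNode, ModelRequestNode, UserPromptNode

# Illustrative agent; any model string / system prompt works the same way.
agent = Agent('openai:gpt-4o', system_prompt='Be concise.')


def describe(node: object) -> str:
    # CallToolsNode is the new name for what 0.0.30 called HandleResponseNode.
    if isinstance(node, UserPromptNode):
        return 'handling the user prompt'
    elif isinstance(node, ModelRequestNode):
        return 'sending a request to the model'
    elif isinstance(node, CallToolsNode):
        return 'processing the model response / running tools'
    return 'other node'
```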
{pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/_agent_graph.py

@@ -2,13 +2,14 @@ from __future__ import annotations as _annotations
 
 import asyncio
 import dataclasses
+import json
 from collections.abc import AsyncIterator, Iterator, Sequence
 from contextlib import asynccontextmanager, contextmanager
 from contextvars import ContextVar
 from dataclasses import field
 from typing import Any, Generic, Literal, Union, cast
 
-import logfire_api
+from opentelemetry.trace import Span, Tracer
 from typing_extensions import TypeGuard, TypeVar, assert_never
 
 from pydantic_graph import BaseNode, Graph, GraphRunContext
@@ -23,6 +24,7 @@ from . import (
     result,
     usage as _usage,
 )
+from .models.instrumented import InstrumentedModel
 from .result import ResultDataT
 from .settings import ModelSettings, merge_model_settings
 from .tools import (
@@ -36,22 +38,11 @@ __all__ = (
     'GraphAgentDeps',
     'UserPromptNode',
     'ModelRequestNode',
-    'HandleResponseNode',
+    'CallToolsNode',
     'build_run_context',
     'capture_run_messages',
 )
 
-_logfire = logfire_api.Logfire(otel_scope='pydantic-ai')
-
-# while waiting for https://github.com/pydantic/logfire/issues/745
-try:
-    import logfire._internal.stack_info
-except ImportError:
-    pass
-else:
-    from pathlib import Path
-
-    logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)
 
 T = TypeVar('T')
 S = TypeVar('S')
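With the logfire-specific instrumentation removed, _agent_graph.py now carries an OpenTelemetry Span and Tracer directly (see the GraphAgentDeps change below) and opens spans via tracer.start_as_current_span. A minimal sketch of the underlying OpenTelemetry pattern using only the public opentelemetry-api surface; the tracer name and attributes here are illustrative, not taken from the package:

```python
from opentelemetry import trace

# Any configured TracerProvider (e.g. from the OTel SDK) is picked up automatically;
# with no provider configured this falls back to a no-op tracer.
tracer = trace.get_tracer('example-instrumentation')

run_step = 1
with tracer.start_as_current_span(
    'preparing model request params', attributes={'run_step': run_step}
) as span:
    # ... do the work that should be traced ...
    span.set_attribute('result_kind', 'tool-call')
```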
@@ -104,7 +95,8 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):
 
     function_tools: dict[str, Tool[DepsT]] = dataclasses.field(repr=False)
 
-    run_span: …
+    run_span: Span
+    tracer: Tracer
 
 
 class AgentNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
@@ -243,12 +235,12 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
 
     request: _messages.ModelRequest
 
-    _result: HandleResponseNode[DepsT, NodeRunEndT] | None = field(default=None, repr=False)
+    _result: CallToolsNode[DepsT, NodeRunEndT] | None = field(default=None, repr=False)
     _did_stream: bool = field(default=False, repr=False)
 
     async def run(
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
-    ) -> HandleResponseNode[DepsT, NodeRunEndT]:
+    ) -> CallToolsNode[DepsT, NodeRunEndT]:
         if self._result is not None:
             return self._result
 
@@ -286,39 +278,33 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
         assert not self._did_stream, 'stream() should only be called once per node'
 
         model_settings, model_request_parameters = await self._prepare_request(ctx)
-        with …
-        … (removed lines 290-300 not captured in this view)
-            request_usage = streamed_response.usage()
-            span.set_attribute('response', model_response)
-            span.set_attribute('usage', request_usage)
+        async with ctx.deps.model.request_stream(
+            ctx.state.message_history, model_settings, model_request_parameters
+        ) as streamed_response:
+            self._did_stream = True
+            ctx.state.usage.incr(_usage.Usage(), requests=1)
+            yield streamed_response
+            # In case the user didn't manually consume the full stream, ensure it is fully consumed here,
+            # otherwise usage won't be properly counted:
+            async for _ in streamed_response:
+                pass
+            model_response = streamed_response.get()
+            request_usage = streamed_response.usage()
 
         self._finish_handling(ctx, model_response, request_usage)
         assert self._result is not None  # this should be set by the previous line
 
     async def _make_request(
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
-    ) -> HandleResponseNode[DepsT, NodeRunEndT]:
+    ) -> CallToolsNode[DepsT, NodeRunEndT]:
         if self._result is not None:
             return self._result
 
         model_settings, model_request_parameters = await self._prepare_request(ctx)
-        … (removed lines 315-318 not captured in this view)
-        ctx.state.usage.incr(_usage.Usage(), requests=1)
-        span.set_attribute('response', model_response)
-        span.set_attribute('usage', request_usage)
+        model_response, request_usage = await ctx.deps.model.request(
+            ctx.state.message_history, model_settings, model_request_parameters
+        )
+        ctx.state.usage.incr(_usage.Usage(), requests=1)
 
         return self._finish_handling(ctx, model_response, request_usage)
 
@@ -335,7 +321,9 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
         ctx.state.run_step += 1
 
         model_settings = merge_model_settings(ctx.deps.model_settings, None)
-        with …
+        with ctx.deps.tracer.start_as_current_span(
+            'preparing model request params', attributes=dict(run_step=ctx.state.run_step)
+        ):
             model_request_parameters = await _prepare_request_parameters(ctx)
         return model_settings, model_request_parameters
 
@@ -344,7 +332,7 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
         ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
         response: _messages.ModelResponse,
         usage: _usage.Usage,
-    ) -> HandleResponseNode[DepsT, NodeRunEndT]:
+    ) -> CallToolsNode[DepsT, NodeRunEndT]:
         # Update usage
         ctx.state.usage.incr(usage, requests=0)
         if ctx.deps.usage_limits:
@@ -354,13 +342,13 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
         ctx.state.message_history.append(response)
 
         # Set the `_result` attribute since we can't use `return` in an async iterator
-        self._result = HandleResponseNode(response)
+        self._result = CallToolsNode(response)
 
         return self._result
 
 
 @dataclasses.dataclass
-class HandleResponseNode(AgentNode[DepsT, NodeRunEndT]):
+class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
     """Process a model response, and decide whether to end the run or make a new request."""
 
     model_response: _messages.ModelResponse
@@ -385,26 +373,12 @@ class HandleResponseNode(AgentNode[DepsT, NodeRunEndT]):
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
     ) -> AsyncIterator[AsyncIterator[_messages.HandleResponseEvent]]:
         """Process the model response and yield events for the start and end of each function tool call."""
-        … (removed lines 388-389 not captured in this view)
-            yield stream
-
-            # Run the stream to completion if it was not finished:
-            async for _event in stream:
-                pass
+        stream = self._run_stream(ctx)
+        yield stream
 
-        … (removed lines 396-398 not captured in this view)
-            handle_span.set_attribute('result', next_node.data)
-            handle_span.message = 'handle model response -> final result'
-        elif tool_responses := self._tool_responses:
-            # TODO: We could drop `self._tool_responses` if we drop this set_attribute
-            # I'm thinking it might be better to just create a span for the handling of each tool
-            # than to set an attribute here.
-            handle_span.set_attribute('tool_responses', tool_responses)
-            tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
-            handle_span.message = f'handle model response -> {tool_responses_str}'
+        # Run the stream to completion if it was not finished:
+        async for _event in stream:
+            pass
 
     async def _run_stream(
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
@@ -454,8 +428,7 @@ class HandleResponseNode(AgentNode[DepsT, NodeRunEndT]):
         final_result: result.FinalResult[NodeRunEndT] | None = None
         parts: list[_messages.ModelRequestPart] = []
         if result_schema is not None:
-            … (removed line 457 not captured in this view)
-                call, result_tool = match
+            for call, result_tool in result_schema.find_tool(tool_calls):
                 try:
                     result_data = result_tool.validate(call)
                     result_data = await _validate_result(result_data, ctx, call)
@@ -465,12 +438,17 @@ class HandleResponseNode(AgentNode[DepsT, NodeRunEndT]):
                     ctx.state.increment_retries(ctx.deps.max_result_retries)
                     parts.append(e.tool_retry)
                 else:
-                    final_result = result.FinalResult(result_data, call.tool_name)
+                    final_result = result.FinalResult(result_data, call.tool_name, call.tool_call_id)
+                    break
 
         # Then build the other request parts based on end strategy
         tool_responses: list[_messages.ModelRequestPart] = self._tool_responses
         async for event in process_function_tools(
-            tool_calls, …
+            tool_calls,
+            final_result and final_result.tool_name,
+            final_result and final_result.tool_call_id,
+            ctx,
+            tool_responses,
         ):
             yield event
 
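FinalResult now records the tool_call_id in addition to the tool_name, and process_function_tools matches the winning call on both fields. This matters when a model returns several calls to the same result tool in one response: only the call that actually produced the final result should be acknowledged as processed. A hedged sketch of that matching logic in isolation; the types and values below are illustrative, not the package's internal signatures:

```python
from dataclasses import dataclass


@dataclass
class ToolCall:
    tool_name: str
    tool_call_id: str


# Two calls to the same result tool in one model response:
calls = [ToolCall('final_result', 'call_1'), ToolCall('final_result', 'call_2')]

# Suppose validation succeeded for the second call only:
result_tool_name, result_tool_call_id = 'final_result', 'call_2'

for call in calls:
    if call.tool_name == result_tool_name and call.tool_call_id == result_tool_call_id:
        print(f'{call.tool_call_id}: final result processed')
    else:
        print(f'{call.tool_call_id}: result tool not used')
```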
@@ -495,8 +473,30 @@ class HandleResponseNode(AgentNode[DepsT, NodeRunEndT]):
         if tool_responses:
             messages.append(_messages.ModelRequest(parts=tool_responses))
 
-        run_span.…
-        … (removed line 499 not captured in this view)
+        run_span.set_attributes(
+            {
+                **usage.opentelemetry_attributes(),
+                'all_messages_events': json.dumps(
+                    [InstrumentedModel.event_to_dict(e) for e in InstrumentedModel.messages_to_otel_events(messages)]
+                ),
+                'final_result': final_result.data
+                if isinstance(final_result.data, str)
+                else json.dumps(InstrumentedModel.serialize_any(final_result.data)),
+            }
+        )
+        run_span.set_attributes(
+            {
+                'logfire.json_schema': json.dumps(
+                    {
+                        'type': 'object',
+                        'properties': {
+                            'all_messages_events': {'type': 'array'},
+                            'final_result': {'type': 'object'},
+                        },
+                    }
+                ),
+            }
+        )
 
         # End the run with self.data
         return End(final_result)
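The run span now carries the usage counters plus the full message history and final result as JSON string attributes: OpenTelemetry attribute values must be primitives (or sequences of primitives), so structured data is serialized with json.dumps first. A small sketch of the same technique against plain opentelemetry-api, with illustrative attribute names rather than the ones the package emits:

```python
import json

from opentelemetry import trace

tracer = trace.get_tracer('example')

with tracer.start_as_current_span('agent run') as run_span:
    messages = [{'role': 'user', 'content': 'hello'}, {'role': 'assistant', 'content': 'hi'}]
    # Span attributes can't hold nested dicts, so structured data goes in as JSON strings:
    run_span.set_attributes(
        {
            'requests': 1,  # plain numeric attribute, like the usage counters above
            'all_messages_events': json.dumps(messages),
        }
    )
```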
@@ -518,7 +518,7 @@ class HandleResponseNode(AgentNode[DepsT, NodeRunEndT]):
                 return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
             else:
                 # The following cast is safe because we know `str` is an allowed result type
-                return self._handle_final_result(ctx, result.FinalResult(result_data, None), [])
+                return self._handle_final_result(ctx, result.FinalResult(result_data, None, None), [])
         else:
             ctx.state.increment_retries(ctx.deps.max_result_retries)
             return ModelRequestNode[DepsT, NodeRunEndT](
@@ -547,6 +547,7 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
 async def process_function_tools(
     tool_calls: list[_messages.ToolCallPart],
     result_tool_name: str | None,
+    result_tool_call_id: str | None,
     ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
     output_parts: list[_messages.ModelRequestPart],
 ) -> AsyncIterator[_messages.HandleResponseEvent]:
@@ -566,7 +567,11 @@ async def process_function_tools(
     calls_to_run: list[tuple[Tool[DepsT], _messages.ToolCallPart]] = []
     call_index_to_event_id: dict[int, str] = {}
     for call in tool_calls:
-        if …
+        if (
+            call.tool_name == result_tool_name
+            and call.tool_call_id == result_tool_call_id
+            and not found_used_result_tool
+        ):
             found_used_result_tool = True
             output_parts.append(
                 _messages.ToolReturnPart(
@@ -593,9 +598,14 @@ async def process_function_tools(
             # if tool_name is in _result_schema, it means we found a result tool but an error occurred in
             # validation, we don't add another part here
             if result_tool_name is not None:
+                if found_used_result_tool:
+                    content = 'Result tool not used - a final result was already processed.'
+                else:
+                    # TODO: Include information about the validation failure, and/or merge this with the ModelRetry part
+                    content = 'Result tool not used - result failed validation.'
                 part = _messages.ToolReturnPart(
                     tool_name=call.tool_name,
-                    content=…
+                    content=content,
                     tool_call_id=call.tool_call_id,
                 )
                 output_parts.append(part)
@@ -607,7 +617,10 @@ async def process_function_tools(
 
     # Run all tool tasks in parallel
     results_by_index: dict[int, _messages.ModelRequestPart] = {}
-    … (removed line 610 not captured in this view)
+    tool_names = [call.tool_name for _, call in calls_to_run]
+    with ctx.deps.tracer.start_as_current_span(
+        'running tools', attributes={'tools': tool_names, 'logfire.msg': f'running tools: {", ".join(tool_names)}'}
+    ):
         # TODO: Should we wrap each individual tool call in a dedicated span?
         tasks = [asyncio.create_task(tool.run(call, run_context), name=call.tool_name) for tool, call in calls_to_run]
         pending = tasks
@@ -716,7 +729,7 @@ def build_agent_graph(
     nodes = (
         UserPromptNode[DepsT],
         ModelRequestNode[DepsT],
-        HandleResponseNode[DepsT],
+        CallToolsNode[DepsT],
     )
     graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[ResultT]](
         nodes=nodes,
{pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/_result.py

@@ -3,7 +3,7 @@ from __future__ import annotations as _annotations
 
 import inspect
 import sys
 import types
-from collections.abc import Awaitable, Iterable
+from collections.abc import Awaitable, Iterable, Iterator
 from dataclasses import dataclass, field
 from typing import Any, Callable, Generic, Literal, Union, cast, get_args, get_origin
 
@@ -127,12 +127,12 @@ class ResultSchema(Generic[ResultDataT]):
     def find_tool(
         self,
         parts: Iterable[_messages.ModelResponsePart],
-    ) -> tuple[_messages.ToolCallPart, ResultTool[ResultDataT]] | None:
+    ) -> Iterator[tuple[_messages.ToolCallPart, ResultTool[ResultDataT]]]:
         """Find a tool that matches one of the calls."""
         for part in parts:
             if isinstance(part, _messages.ToolCallPart):
                 if result := self.tools.get(part.tool_name):
-                    return part, result
+                    yield part, result
 
     def tool_names(self) -> list[str]:
         """Return the names of the tools."""
{pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/_utils.py

@@ -48,6 +48,8 @@ def check_object_json_schema(schema: JsonSchemaValue) -> ObjectJsonSchema:
 
     if schema.get('type') == 'object':
         return schema
+    elif schema.get('$ref') is not None:
+        return schema.get('$defs', {}).get(schema['$ref'][8:])  # This removes the initial "#/$defs/".
     else:
         raise UserError('Schema must be an object')
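check_object_json_schema now also accepts schemas whose top level is a $ref into $defs, which Pydantic emits for some models (recursive ones, for example): the code strips the leading "#/$defs/" (eight characters) from the reference and looks the target up in $defs. A small illustration of that lookup on a hand-written schema:

```python
schema = {
    '$ref': '#/$defs/Point',
    '$defs': {
        'Point': {
            'type': 'object',
            'properties': {'x': {'type': 'integer'}, 'y': {'type': 'integer'}},
        }
    },
}

ref = schema['$ref']                           # '#/$defs/Point'
target = schema.get('$defs', {}).get(ref[8:])  # ref[8:] == 'Point', i.e. '#/$defs/' stripped
print(target['type'])                          # 'object'
```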