pydantic-ai-slim 0.3.2__py3-none-any.whl → 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/__init__.py +5 -2
- pydantic_ai/_agent_graph.py +40 -16
- pydantic_ai/_cli.py +7 -3
- pydantic_ai/_function_schema.py +1 -4
- pydantic_ai/_output.py +654 -159
- pydantic_ai/_run_context.py +56 -0
- pydantic_ai/_system_prompt.py +2 -1
- pydantic_ai/_utils.py +111 -1
- pydantic_ai/agent.py +57 -34
- pydantic_ai/messages.py +20 -11
- pydantic_ai/models/__init__.py +21 -2
- pydantic_ai/models/anthropic.py +7 -9
- pydantic_ai/models/function.py +21 -3
- pydantic_ai/models/gemini.py +27 -4
- pydantic_ai/models/google.py +29 -4
- pydantic_ai/models/instrumented.py +5 -1
- pydantic_ai/models/mistral.py +5 -1
- pydantic_ai/models/openai.py +70 -9
- pydantic_ai/models/test.py +1 -1
- pydantic_ai/models/wrapper.py +6 -0
- pydantic_ai/output.py +288 -0
- pydantic_ai/profiles/__init__.py +21 -0
- pydantic_ai/profiles/_json_schema.py +1 -1
- pydantic_ai/profiles/google.py +6 -2
- pydantic_ai/profiles/openai.py +5 -0
- pydantic_ai/result.py +52 -26
- pydantic_ai/tools.py +5 -49
- {pydantic_ai_slim-0.3.2.dist-info → pydantic_ai_slim-0.3.4.dist-info}/METADATA +4 -4
- {pydantic_ai_slim-0.3.2.dist-info → pydantic_ai_slim-0.3.4.dist-info}/RECORD +32 -30
- {pydantic_ai_slim-0.3.2.dist-info → pydantic_ai_slim-0.3.4.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.3.2.dist-info → pydantic_ai_slim-0.3.4.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.3.2.dist-info → pydantic_ai_slim-0.3.4.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
from __future__ import annotations as _annotations
|
|
2
|
+
|
|
3
|
+
import dataclasses
|
|
4
|
+
from collections.abc import Sequence
|
|
5
|
+
from dataclasses import field
|
|
6
|
+
from typing import TYPE_CHECKING, Generic
|
|
7
|
+
|
|
8
|
+
from typing_extensions import TypeVar
|
|
9
|
+
|
|
10
|
+
from . import _utils, messages as _messages
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from .models import Model
|
|
14
|
+
from .result import Usage
|
|
15
|
+
|
|
16
|
+
AgentDepsT = TypeVar('AgentDepsT', default=None, contravariant=True)
"""Type variable for agent dependencies."""


@dataclasses.dataclass(repr=False)
class RunContext(Generic[AgentDepsT]):
    """Information about the current call."""

    deps: AgentDepsT
    """Dependencies for the agent."""
    model: Model
    """The model used in this run."""
    usage: Usage
    """LLM usage associated with the run."""
    prompt: str | Sequence[_messages.UserContent] | None
    """The original user prompt passed to the run."""
    messages: list[_messages.ModelMessage] = field(default_factory=list)
    """Messages exchanged in the conversation so far."""
    tool_call_id: str | None = None
    """The ID of the tool call."""
    tool_name: str | None = None
    """Name of the tool being called."""
    retry: int = 0
    """Number of retries so far."""
    run_step: int = 0
    """The current step in the run."""

    def replace_with(
        self,
        retry: int | None = None,
        tool_name: str | None | _utils.Unset = _utils.UNSET,
    ) -> RunContext[AgentDepsT]:
        """Return a copy of this context with `retry` and/or `tool_name` overridden.

        Args:
            retry: The new retry count; `None` keeps the current value.
            tool_name: The new tool name (which may itself be `None`); `_utils.UNSET` keeps the current value.

        Returns:
            A new `RunContext` created via `dataclasses.replace`; `self` is left unchanged.
        """
        # Each entry is (field name, requested value, whether the caller actually supplied it).
        candidates = (
            ('retry', retry, retry is not None),
            ('tool_name', tool_name, tool_name is not _utils.UNSET),
        )
        overrides = {field_name: value for field_name, value, supplied in candidates if supplied}
        return dataclasses.replace(self, **overrides)

    __repr__ = _utils.dataclasses_no_defaults_repr
|
pydantic_ai/_system_prompt.py
CHANGED
|
@@ -6,7 +6,8 @@ from dataclasses import dataclass, field
|
|
|
6
6
|
from typing import Any, Callable, Generic, cast
|
|
7
7
|
|
|
8
8
|
from . import _utils
|
|
9
|
-
from .
|
|
9
|
+
from ._run_context import AgentDepsT, RunContext
|
|
10
|
+
from .tools import SystemPromptFunc
|
|
10
11
|
|
|
11
12
|
|
|
12
13
|
@dataclass
|
pydantic_ai/_utils.py
CHANGED
|
@@ -3,6 +3,7 @@ from __future__ import annotations as _annotations
|
|
|
3
3
|
import asyncio
|
|
4
4
|
import functools
|
|
5
5
|
import inspect
|
|
6
|
+
import re
|
|
6
7
|
import time
|
|
7
8
|
import uuid
|
|
8
9
|
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterator
|
|
@@ -16,7 +17,17 @@ from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, overlo
|
|
|
16
17
|
from anyio.to_thread import run_sync
|
|
17
18
|
from pydantic import BaseModel, TypeAdapter
|
|
18
19
|
from pydantic.json_schema import JsonSchemaValue
|
|
19
|
-
from typing_extensions import
|
|
20
|
+
from typing_extensions import (
|
|
21
|
+
ParamSpec,
|
|
22
|
+
TypeAlias,
|
|
23
|
+
TypeGuard,
|
|
24
|
+
TypeIs,
|
|
25
|
+
get_args,
|
|
26
|
+
get_origin,
|
|
27
|
+
is_typeddict,
|
|
28
|
+
)
|
|
29
|
+
from typing_inspection import typing_objects
|
|
30
|
+
from typing_inspection.introspection import is_union_origin
|
|
20
31
|
|
|
21
32
|
from pydantic_graph._utils import AbstractSpan
|
|
22
33
|
|
|
@@ -327,3 +338,102 @@ def is_async_callable(obj: Any) -> Any:
|
|
|
327
338
|
obj = obj.func
|
|
328
339
|
|
|
329
340
|
return inspect.iscoroutinefunction(obj) or (callable(obj) and inspect.iscoroutinefunction(obj.__call__)) # type: ignore
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
def _update_mapped_json_schema_refs(s: dict[str, Any], name_mapping: dict[str, str]) -> None:
    """Update $refs in a schema, in place, to use the new names from name_mapping.

    Only local `#/$defs/...` refs are rewritten; names missing from `name_mapping` are kept as-is.
    The walk descends into every sub-schema location that can contain a `$ref`: `properties`,
    `patternProperties`, `additionalProperties`, `items`, `prefixItems`, `anyOf`, `oneOf`, `allOf` and `not`.
    Boolean sub-schemas (`true`/`false`) hold no refs and are skipped.
    """
    prefix = '#/$defs/'
    ref = s.get('$ref')
    if isinstance(ref, str) and ref.startswith(prefix):
        original_name = ref[len(prefix) :]
        s['$ref'] = f'{prefix}{name_mapping.get(original_name, original_name)}'

    # Keywords mapping a key (property name / regex) to a sub-schema.
    for key in ('properties', 'patternProperties'):
        sub_schemas = s.get(key)
        if isinstance(sub_schemas, dict):
            for sub_schema in sub_schemas.values():
                if isinstance(sub_schema, dict):
                    _update_mapped_json_schema_refs(sub_schema, name_mapping)

    # Keywords holding a single sub-schema; e.g. `dict[str, Model]` puts its ref under `additionalProperties`.
    for key in ('items', 'additionalProperties', 'not'):
        sub_schema = s.get(key)
        if isinstance(sub_schema, dict):
            _update_mapped_json_schema_refs(sub_schema, name_mapping)

    # Keywords holding a list of sub-schemas (tuples and unions/intersections).
    for key in ('prefixItems', 'anyOf', 'oneOf', 'allOf'):
        sub_schemas = s.get(key)
        if isinstance(sub_schemas, list):
            for sub_schema in sub_schemas:
                if isinstance(sub_schema, dict):
                    _update_mapped_json_schema_refs(sub_schema, name_mapping)


def merge_json_schema_defs(schemas: list[dict[str, Any]]) -> tuple[list[dict[str, Any]], dict[str, dict[str, Any]]]:
    """Merges the `$defs` from different JSON schemas into a single deduplicated `$defs`, handling name collisions of `$defs` that are not the same, and rewrites `$ref`s to point to the new `$defs`.

    Schemas without `$defs` are passed through as the same object. Schemas with `$defs` are deep-copied
    before rewriting, so the input schemas are never mutated.

    A definition whose name is already used by a *different* definition is renamed to `{title}_{name}_{i}`
    (or `{name}_{i}` when the schema has no `title`), using the smallest free `i >= 1`. `$ref`s are rewritten
    both in the schema and inside its own definitions. Definitions are compared *after* rewriting, so a
    definition that refers to a renamed sibling is itself renamed rather than silently merged with a
    same-named definition that refers to something else.

    Returns a tuple of the rewritten schemas and a dictionary of the new `$defs`.
    """
    import copy

    all_defs: dict[str, dict[str, Any]] = {}
    rewritten_schemas: list[dict[str, Any]] = []

    for schema in schemas:
        if '$defs' not in schema:
            rewritten_schemas.append(schema)
            continue

        # Deep copy: refs are rewritten in place, and a shallow copy would share nested dicts with the caller.
        schema = copy.deepcopy(schema)
        defs: dict[str, dict[str, Any]] = schema.pop('$defs') or {}
        title = schema.get('title')

        name_mapping: dict[str, str] = {name: name for name in defs}
        # Names a renamed definition must not take: already-merged defs, this schema's own def names,
        # and names already handed out to other renamed defs of this schema.
        taken: set[str] = set(all_defs) | set(defs)

        # Renaming one definition changes the refs inside siblings that point to it, which can turn a
        # previously identical sibling into a conflicting one; repeat until no more renames happen.
        # This terminates because each definition is renamed at most once.
        changed = True
        while changed:
            changed = False
            for name, def_schema in defs.items():
                if name_mapping[name] != name or name not in all_defs:
                    continue  # already renamed, or no collision at all
                candidate = copy.deepcopy(def_schema)
                _update_mapped_json_schema_refs(candidate, name_mapping)
                if candidate != all_defs[name]:
                    base = f'{title}_{name}' if title else name
                    i = 1
                    while f'{base}_{i}' in taken:
                        i += 1
                    new_name = f'{base}_{i}'
                    name_mapping[name] = new_name
                    taken.add(new_name)
                    changed = True

        for name, def_schema in defs.items():
            _update_mapped_json_schema_refs(def_schema, name_mapping)
            # An identical definition already stored under the same name is reused, not duplicated.
            all_defs.setdefault(name_mapping[name], def_schema)

        _update_mapped_json_schema_refs(schema, name_mapping)
        rewritten_schemas.append(schema)

    return rewritten_schemas, all_defs
|
|
416
|
+
|
|
417
|
+
|
|
418
|
+
def strip_markdown_fences(text: str) -> str:
    """Extract a JSON object from a Markdown code fence, if `text` is wrapped in one.

    Text that already starts with `{` is returned unchanged. Otherwise, the first fenced block
    (with an optional language tag such as `json`) whose content is a `{...}` object is returned.
    Both LF and CRLF line endings are accepted, as is trailing horizontal whitespace on the fence lines.
    If no such block is found, `text` is returned unchanged.
    """
    if text.startswith('{'):
        return text

    # Non-greedy body: with several fenced blocks, capture only the first object instead of everything
    # between the first `{` and the very last `}` (which would not be valid JSON).
    match = re.search(r'```(?:\w+)?[ \t]*\r?\n(\{.*?\})[ \t]*\r?\n```', text, re.DOTALL)
    if match:
        return match.group(1)

    return text
|
|
428
|
+
|
|
429
|
+
|
|
430
|
+
def get_union_args(tp: Any) -> tuple[Any, ...]:
    """Extract the arguments of a Union type if `tp` is a union, otherwise return an empty tuple.

    `TypeAliasType` aliases are unwrapped first, including chains of aliases. For example, with
    `type A = int | str` and `type B = A`, both `A` and `B` yield `(int, str)`. A self-referential
    alias stops the unwrapping instead of looping forever.
    """
    seen: set[int] = set()
    while typing_objects.is_typealiastype(tp) and id(tp) not in seen:
        seen.add(id(tp))
        tp = tp.__value__

    if is_union_origin(get_origin(tp)):
        return get_args(tp)
    return ()
|
pydantic_ai/agent.py
CHANGED
|
@@ -14,6 +14,7 @@ from opentelemetry.trace import NoOpTracer, use_span
|
|
|
14
14
|
from pydantic.json_schema import GenerateJsonSchema
|
|
15
15
|
from typing_extensions import Literal, Never, Self, TypeIs, TypeVar, deprecated
|
|
16
16
|
|
|
17
|
+
from pydantic_ai.profiles import ModelProfile
|
|
17
18
|
from pydantic_graph import End, Graph, GraphRun, GraphRunContext
|
|
18
19
|
from pydantic_graph._utils import get_event_loop
|
|
19
20
|
|
|
@@ -30,7 +31,8 @@ from . import (
|
|
|
30
31
|
)
|
|
31
32
|
from ._agent_graph import HistoryProcessor
|
|
32
33
|
from .models.instrumented import InstrumentationSettings, InstrumentedModel, instrument_model
|
|
33
|
-
from .
|
|
34
|
+
from .output import OutputDataT, OutputSpec
|
|
35
|
+
from .result import FinalResult, StreamedRunResult
|
|
34
36
|
from .settings import ModelSettings, merge_model_settings
|
|
35
37
|
from .tools import (
|
|
36
38
|
AgentDepsT,
|
|
@@ -91,7 +93,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
91
93
|
"""Class for defining "agents" - a way to have a specific type of "conversation" with an LLM.
|
|
92
94
|
|
|
93
95
|
Agents are generic in the dependency type they take [`AgentDepsT`][pydantic_ai.tools.AgentDepsT]
|
|
94
|
-
and the
|
|
96
|
+
and the output type they return, [`OutputDataT`][pydantic_ai.output.OutputDataT].
|
|
95
97
|
|
|
96
98
|
By default, if neither generic parameter is customised, agents have type `Agent[None, str]`.
|
|
97
99
|
|
|
@@ -128,7 +130,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
128
130
|
be merged with this value, with the runtime argument taking priority.
|
|
129
131
|
"""
|
|
130
132
|
|
|
131
|
-
output_type:
|
|
133
|
+
output_type: OutputSpec[OutputDataT]
|
|
132
134
|
"""
|
|
133
135
|
The type of data output by agent runs, used to validate the data returned by the model, defaults to `str`.
|
|
134
136
|
"""
|
|
@@ -141,7 +143,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
141
143
|
_deps_type: type[AgentDepsT] = dataclasses.field(repr=False)
|
|
142
144
|
_deprecated_result_tool_name: str | None = dataclasses.field(repr=False)
|
|
143
145
|
_deprecated_result_tool_description: str | None = dataclasses.field(repr=False)
|
|
144
|
-
_output_schema: _output.
|
|
146
|
+
_output_schema: _output.BaseOutputSchema[OutputDataT] = dataclasses.field(repr=False)
|
|
145
147
|
_output_validators: list[_output.OutputValidator[AgentDepsT, OutputDataT]] = dataclasses.field(repr=False)
|
|
146
148
|
_instructions: str | None = dataclasses.field(repr=False)
|
|
147
149
|
_instructions_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(repr=False)
|
|
@@ -163,7 +165,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
163
165
|
self,
|
|
164
166
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
165
167
|
*,
|
|
166
|
-
output_type:
|
|
168
|
+
output_type: OutputSpec[OutputDataT] = str,
|
|
167
169
|
instructions: str
|
|
168
170
|
| _system_prompt.SystemPromptFunc[AgentDepsT]
|
|
169
171
|
| Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]]
|
|
@@ -324,8 +326,14 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
324
326
|
warnings.warn('`result_retries` is deprecated, use `max_result_retries` instead', DeprecationWarning)
|
|
325
327
|
output_retries = result_retries
|
|
326
328
|
|
|
329
|
+
default_output_mode = (
|
|
330
|
+
self.model.profile.default_structured_output_mode if isinstance(self.model, models.Model) else None
|
|
331
|
+
)
|
|
327
332
|
self._output_schema = _output.OutputSchema[OutputDataT].build(
|
|
328
|
-
output_type,
|
|
333
|
+
output_type,
|
|
334
|
+
default_mode=default_output_mode,
|
|
335
|
+
name=self._deprecated_result_tool_name,
|
|
336
|
+
description=self._deprecated_result_tool_description,
|
|
329
337
|
)
|
|
330
338
|
self._output_validators = []
|
|
331
339
|
|
|
@@ -382,7 +390,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
382
390
|
self,
|
|
383
391
|
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
384
392
|
*,
|
|
385
|
-
output_type:
|
|
393
|
+
output_type: OutputSpec[RunOutputDataT],
|
|
386
394
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
387
395
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
388
396
|
deps: AgentDepsT = None,
|
|
@@ -412,7 +420,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
412
420
|
self,
|
|
413
421
|
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
414
422
|
*,
|
|
415
|
-
output_type:
|
|
423
|
+
output_type: OutputSpec[RunOutputDataT] | None = None,
|
|
416
424
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
417
425
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
418
426
|
deps: AgentDepsT = None,
|
|
@@ -500,7 +508,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
500
508
|
self,
|
|
501
509
|
user_prompt: str | Sequence[_messages.UserContent] | None,
|
|
502
510
|
*,
|
|
503
|
-
output_type:
|
|
511
|
+
output_type: OutputSpec[RunOutputDataT],
|
|
504
512
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
505
513
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
506
514
|
deps: AgentDepsT = None,
|
|
@@ -532,7 +540,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
532
540
|
self,
|
|
533
541
|
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
534
542
|
*,
|
|
535
|
-
output_type:
|
|
543
|
+
output_type: OutputSpec[RunOutputDataT] | None = None,
|
|
536
544
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
537
545
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
538
546
|
deps: AgentDepsT = None,
|
|
@@ -631,7 +639,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
631
639
|
|
|
632
640
|
deps = self._get_deps(deps)
|
|
633
641
|
new_message_index = len(message_history) if message_history else 0
|
|
634
|
-
output_schema = self._prepare_output_schema(output_type)
|
|
642
|
+
output_schema = self._prepare_output_schema(output_type, model_used.profile)
|
|
635
643
|
|
|
636
644
|
output_type_ = output_type or self.output_type
|
|
637
645
|
|
|
@@ -674,14 +682,20 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
674
682
|
)
|
|
675
683
|
|
|
676
684
|
async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
|
|
677
|
-
|
|
678
|
-
|
|
685
|
+
parts = [
|
|
686
|
+
self._instructions,
|
|
687
|
+
*[await func.run(run_context) for func in self._instructions_functions],
|
|
688
|
+
]
|
|
679
689
|
|
|
680
|
-
|
|
681
|
-
|
|
682
|
-
instructions.
|
|
683
|
-
|
|
684
|
-
|
|
690
|
+
model_profile = model_used.profile
|
|
691
|
+
if isinstance(output_schema, _output.PromptedOutputSchema):
|
|
692
|
+
instructions = output_schema.instructions(model_profile.prompted_output_template)
|
|
693
|
+
parts.append(instructions)
|
|
694
|
+
|
|
695
|
+
parts = [p for p in parts if p]
|
|
696
|
+
if not parts:
|
|
697
|
+
return None
|
|
698
|
+
return '\n\n'.join(parts).strip()
|
|
685
699
|
|
|
686
700
|
# Copy the function tools so that retry state is agent-run-specific
|
|
687
701
|
# Note that the retry count is reset to 0 when this happens due to the `default=0` and `init=False`.
|
|
@@ -705,6 +719,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
705
719
|
tracer=tracer,
|
|
706
720
|
prepare_tools=self._prepare_tools,
|
|
707
721
|
get_instructions=get_instructions,
|
|
722
|
+
instrumentation_settings=instrumentation_settings,
|
|
708
723
|
)
|
|
709
724
|
start_node = _agent_graph.UserPromptNode[AgentDepsT](
|
|
710
725
|
user_prompt=user_prompt,
|
|
@@ -779,7 +794,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
779
794
|
self,
|
|
780
795
|
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
781
796
|
*,
|
|
782
|
-
output_type:
|
|
797
|
+
output_type: OutputSpec[RunOutputDataT] | None = None,
|
|
783
798
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
784
799
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
785
800
|
deps: AgentDepsT = None,
|
|
@@ -809,7 +824,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
809
824
|
self,
|
|
810
825
|
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
811
826
|
*,
|
|
812
|
-
output_type:
|
|
827
|
+
output_type: OutputSpec[RunOutputDataT] | None = None,
|
|
813
828
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
814
829
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
815
830
|
deps: AgentDepsT = None,
|
|
@@ -892,7 +907,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
892
907
|
self,
|
|
893
908
|
user_prompt: str | Sequence[_messages.UserContent],
|
|
894
909
|
*,
|
|
895
|
-
output_type:
|
|
910
|
+
output_type: OutputSpec[RunOutputDataT],
|
|
896
911
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
897
912
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
898
913
|
deps: AgentDepsT = None,
|
|
@@ -923,7 +938,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
923
938
|
self,
|
|
924
939
|
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
925
940
|
*,
|
|
926
|
-
output_type:
|
|
941
|
+
output_type: OutputSpec[RunOutputDataT] | None = None,
|
|
927
942
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
928
943
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
929
944
|
deps: AgentDepsT = None,
|
|
@@ -1002,10 +1017,13 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
1002
1017
|
async for maybe_part_event in streamed_response:
|
|
1003
1018
|
if isinstance(maybe_part_event, _messages.PartStartEvent):
|
|
1004
1019
|
new_part = maybe_part_event.part
|
|
1005
|
-
if isinstance(new_part, _messages.TextPart)
|
|
1006
|
-
|
|
1007
|
-
|
|
1008
|
-
|
|
1020
|
+
if isinstance(new_part, _messages.TextPart) and isinstance(
|
|
1021
|
+
output_schema, _output.TextOutputSchema
|
|
1022
|
+
):
|
|
1023
|
+
return FinalResult(s, None, None)
|
|
1024
|
+
elif isinstance(new_part, _messages.ToolCallPart) and isinstance(
|
|
1025
|
+
output_schema, _output.ToolOutputSchema
|
|
1026
|
+
): # pragma: no branch
|
|
1009
1027
|
for call, _ in output_schema.find_tool([new_part]):
|
|
1010
1028
|
return FinalResult(s, call.tool_name, call.tool_call_id)
|
|
1011
1029
|
return None
|
|
@@ -1561,8 +1579,8 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
1561
1579
|
if tool.name in self._function_tools:
|
|
1562
1580
|
raise exceptions.UserError(f'Tool name conflicts with existing tool: {tool.name!r}')
|
|
1563
1581
|
|
|
1564
|
-
if
|
|
1565
|
-
raise exceptions.UserError(f'Tool name conflicts with
|
|
1582
|
+
if tool.name in self._output_schema.tools:
|
|
1583
|
+
raise exceptions.UserError(f'Tool name conflicts with output tool name: {tool.name!r}')
|
|
1566
1584
|
|
|
1567
1585
|
self._function_tools[tool.name] = tool
|
|
1568
1586
|
|
|
@@ -1637,18 +1655,23 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
1637
1655
|
raise AttributeError('The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.')
|
|
1638
1656
|
|
|
1639
1657
|
def _prepare_output_schema(
|
|
1640
|
-
self, output_type:
|
|
1641
|
-
) -> _output.OutputSchema[RunOutputDataT]
|
|
1658
|
+
self, output_type: OutputSpec[RunOutputDataT] | None, model_profile: ModelProfile
|
|
1659
|
+
) -> _output.OutputSchema[RunOutputDataT]:
|
|
1642
1660
|
if output_type is not None:
|
|
1643
1661
|
if self._output_validators:
|
|
1644
1662
|
raise exceptions.UserError('Cannot set a custom run `output_type` when the agent has output validators')
|
|
1645
|
-
|
|
1663
|
+
schema = _output.OutputSchema[RunOutputDataT].build(
|
|
1646
1664
|
output_type,
|
|
1647
|
-
self._deprecated_result_tool_name,
|
|
1648
|
-
self._deprecated_result_tool_description,
|
|
1665
|
+
name=self._deprecated_result_tool_name,
|
|
1666
|
+
description=self._deprecated_result_tool_description,
|
|
1667
|
+
default_mode=model_profile.default_structured_output_mode,
|
|
1649
1668
|
)
|
|
1650
1669
|
else:
|
|
1651
|
-
|
|
1670
|
+
schema = self._output_schema.with_default_mode(model_profile.default_structured_output_mode)
|
|
1671
|
+
|
|
1672
|
+
schema.raise_if_unsupported(model_profile)
|
|
1673
|
+
|
|
1674
|
+
return schema # pyright: ignore[reportReturnType]
|
|
1652
1675
|
|
|
1653
1676
|
@staticmethod
|
|
1654
1677
|
def is_model_request_node(
|
pydantic_ai/messages.py
CHANGED
|
@@ -76,8 +76,11 @@ class SystemPromptPart:
|
|
|
76
76
|
part_kind: Literal['system-prompt'] = 'system-prompt'
|
|
77
77
|
"""Part type identifier, this is available on all parts as a discriminator."""
|
|
78
78
|
|
|
79
|
-
def otel_event(self,
|
|
80
|
-
return Event(
|
|
79
|
+
def otel_event(self, settings: InstrumentationSettings) -> Event:
|
|
80
|
+
return Event(
|
|
81
|
+
'gen_ai.system.message',
|
|
82
|
+
body={'role': 'system', **({'content': self.content} if settings.include_content else {})},
|
|
83
|
+
)
|
|
81
84
|
|
|
82
85
|
__repr__ = _utils.dataclasses_no_defaults_repr
|
|
83
86
|
|
|
@@ -362,12 +365,12 @@ class UserPromptPart:
|
|
|
362
365
|
content = []
|
|
363
366
|
for part in self.content:
|
|
364
367
|
if isinstance(part, str):
|
|
365
|
-
content.append(part)
|
|
368
|
+
content.append(part if settings.include_content else {'kind': 'text'})
|
|
366
369
|
elif isinstance(part, (ImageUrl, AudioUrl, DocumentUrl, VideoUrl)):
|
|
367
|
-
content.append({'kind': part.kind, 'url': part.url})
|
|
370
|
+
content.append({'kind': part.kind, **({'url': part.url} if settings.include_content else {})})
|
|
368
371
|
elif isinstance(part, BinaryContent):
|
|
369
372
|
converted_part = {'kind': part.kind, 'media_type': part.media_type}
|
|
370
|
-
if settings.include_binary_content:
|
|
373
|
+
if settings.include_content and settings.include_binary_content:
|
|
371
374
|
converted_part['binary_content'] = base64.b64encode(part.data).decode()
|
|
372
375
|
content.append(converted_part)
|
|
373
376
|
else:
|
|
@@ -414,10 +417,15 @@ class ToolReturnPart:
|
|
|
414
417
|
else:
|
|
415
418
|
return {'return_value': tool_return_ta.dump_python(self.content, mode='json')}
|
|
416
419
|
|
|
417
|
-
def otel_event(self,
|
|
420
|
+
def otel_event(self, settings: InstrumentationSettings) -> Event:
|
|
418
421
|
return Event(
|
|
419
422
|
'gen_ai.tool.message',
|
|
420
|
-
body={
|
|
423
|
+
body={
|
|
424
|
+
**({'content': self.content} if settings.include_content else {}),
|
|
425
|
+
'role': 'tool',
|
|
426
|
+
'id': self.tool_call_id,
|
|
427
|
+
'name': self.tool_name,
|
|
428
|
+
},
|
|
421
429
|
)
|
|
422
430
|
|
|
423
431
|
__repr__ = _utils.dataclasses_no_defaults_repr
|
|
@@ -473,14 +481,14 @@ class RetryPromptPart:
|
|
|
473
481
|
description = f'{len(self.content)} validation errors: {json_errors.decode()}'
|
|
474
482
|
return f'{description}\n\nFix the errors and try again.'
|
|
475
483
|
|
|
476
|
-
def otel_event(self,
|
|
484
|
+
def otel_event(self, settings: InstrumentationSettings) -> Event:
|
|
477
485
|
if self.tool_name is None:
|
|
478
486
|
return Event('gen_ai.user.message', body={'content': self.model_response(), 'role': 'user'})
|
|
479
487
|
else:
|
|
480
488
|
return Event(
|
|
481
489
|
'gen_ai.tool.message',
|
|
482
490
|
body={
|
|
483
|
-
'content': self.model_response(),
|
|
491
|
+
**({'content': self.model_response()} if settings.include_content else {}),
|
|
484
492
|
'role': 'tool',
|
|
485
493
|
'id': self.tool_call_id,
|
|
486
494
|
'name': self.tool_name,
|
|
@@ -657,7 +665,7 @@ class ModelResponse:
|
|
|
657
665
|
vendor_id: str | None = None
|
|
658
666
|
"""Vendor ID as specified by the model provider. This can be used to track the specific request to the model."""
|
|
659
667
|
|
|
660
|
-
def otel_events(self) -> list[Event]:
|
|
668
|
+
def otel_events(self, settings: InstrumentationSettings) -> list[Event]:
|
|
661
669
|
"""Return OpenTelemetry events for the response."""
|
|
662
670
|
result: list[Event] = []
|
|
663
671
|
|
|
@@ -683,7 +691,8 @@ class ModelResponse:
|
|
|
683
691
|
elif isinstance(part, TextPart):
|
|
684
692
|
if body.get('content'):
|
|
685
693
|
body = new_event_body()
|
|
686
|
-
|
|
694
|
+
if settings.include_content:
|
|
695
|
+
body['content'] = part.content
|
|
687
696
|
|
|
688
697
|
return result
|
|
689
698
|
|
pydantic_ai/models/__init__.py
CHANGED
|
@@ -20,9 +20,12 @@ from typing_extensions import Literal, TypeAliasType, TypedDict
|
|
|
20
20
|
|
|
21
21
|
from pydantic_ai.profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec
|
|
22
22
|
|
|
23
|
+
from .. import _utils
|
|
24
|
+
from .._output import OutputObjectDefinition
|
|
23
25
|
from .._parts_manager import ModelResponsePartsManager
|
|
24
26
|
from ..exceptions import UserError
|
|
25
27
|
from ..messages import FileUrl, ModelMessage, ModelRequest, ModelResponse, ModelResponseStreamEvent, VideoUrl
|
|
28
|
+
from ..output import OutputMode
|
|
26
29
|
from ..profiles._json_schema import JsonSchemaTransformer
|
|
27
30
|
from ..settings import ModelSettings
|
|
28
31
|
from ..tools import ToolDefinition
|
|
@@ -300,13 +303,18 @@ KnownModelName = TypeAliasType(
|
|
|
300
303
|
"""
|
|
301
304
|
|
|
302
305
|
|
|
303
|
-
@dataclass
|
|
306
|
+
@dataclass(repr=False)
|
|
304
307
|
class ModelRequestParameters:
|
|
305
308
|
"""Configuration for an agent's request to a model, specifically related to tools and output handling."""
|
|
306
309
|
|
|
307
310
|
function_tools: list[ToolDefinition] = field(default_factory=list)
|
|
308
|
-
|
|
311
|
+
|
|
312
|
+
output_mode: OutputMode = 'text'
|
|
313
|
+
output_object: OutputObjectDefinition | None = None
|
|
309
314
|
output_tools: list[ToolDefinition] = field(default_factory=list)
|
|
315
|
+
allow_text_output: bool = True
|
|
316
|
+
|
|
317
|
+
__repr__ = _utils.dataclasses_no_defaults_repr
|
|
310
318
|
|
|
311
319
|
|
|
312
320
|
class Model(ABC):
|
|
@@ -351,6 +359,11 @@ class Model(ABC):
|
|
|
351
359
|
function_tools=[_customize_tool_def(transformer, t) for t in model_request_parameters.function_tools],
|
|
352
360
|
output_tools=[_customize_tool_def(transformer, t) for t in model_request_parameters.output_tools],
|
|
353
361
|
)
|
|
362
|
+
if output_object := model_request_parameters.output_object:
|
|
363
|
+
model_request_parameters = replace(
|
|
364
|
+
model_request_parameters,
|
|
365
|
+
output_object=_customize_output_object(transformer, output_object),
|
|
366
|
+
)
|
|
354
367
|
|
|
355
368
|
return model_request_parameters
|
|
356
369
|
|
|
@@ -718,3 +731,9 @@ def _customize_tool_def(transformer: type[JsonSchemaTransformer], t: ToolDefinit
|
|
|
718
731
|
if t.strict is None:
|
|
719
732
|
t = replace(t, strict=schema_transformer.is_strict_compatible)
|
|
720
733
|
return replace(t, parameters_json_schema=parameters_json_schema)
|
|
734
|
+
|
|
735
|
+
|
|
736
|
+
def _customize_output_object(
    transformer: type[JsonSchemaTransformer], o: OutputObjectDefinition
) -> OutputObjectDefinition:
    """Return a copy of `o` whose JSON schema has been walked by `transformer`.

    The transformer is always constructed with `strict=True` for output objects.
    """
    transformed_schema = transformer(o.json_schema, strict=True).walk()
    return replace(o, json_schema=transformed_schema)
|
pydantic_ai/models/anthropic.py
CHANGED
|
@@ -342,15 +342,13 @@ class AnthropicModel(Model):
|
|
|
342
342
|
if response_part.content: # Only add non-empty text
|
|
343
343
|
assistant_content_params.append(BetaTextBlockParam(text=response_part.content, type='text'))
|
|
344
344
|
elif isinstance(response_part, ThinkingPart):
|
|
345
|
-
# NOTE: We
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
# )
|
|
353
|
-
pass
|
|
345
|
+
# NOTE: We only send thinking part back for Anthropic, otherwise they raise an error.
|
|
346
|
+
if response_part.signature is not None: # pragma: no branch
|
|
347
|
+
assistant_content_params.append(
|
|
348
|
+
BetaThinkingBlockParam(
|
|
349
|
+
thinking=response_part.content, signature=response_part.signature, type='thinking'
|
|
350
|
+
)
|
|
351
|
+
)
|
|
354
352
|
else:
|
|
355
353
|
tool_use_block_param = BetaToolUseBlockParam(
|
|
356
354
|
id=_guard_tool_call_id(t=response_part),
|
pydantic_ai/models/function.py
CHANGED
|
@@ -11,6 +11,8 @@ from typing import Callable, Union
|
|
|
11
11
|
|
|
12
12
|
from typing_extensions import TypeAlias, assert_never, overload
|
|
13
13
|
|
|
14
|
+
from pydantic_ai.profiles import ModelProfileSpec
|
|
15
|
+
|
|
14
16
|
from .. import _utils, usage
|
|
15
17
|
from .._utils import PeekableAsyncStream
|
|
16
18
|
from ..messages import (
|
|
@@ -49,14 +51,27 @@ class FunctionModel(Model):
|
|
|
49
51
|
_system: str = field(default='function', repr=False)
|
|
50
52
|
|
|
51
53
|
@overload
|
|
52
|
-
def __init__(
|
|
54
|
+
def __init__(
|
|
55
|
+
self, function: FunctionDef, *, model_name: str | None = None, profile: ModelProfileSpec | None = None
|
|
56
|
+
) -> None: ...
|
|
53
57
|
|
|
54
58
|
@overload
|
|
55
|
-
def __init__(
|
|
59
|
+
def __init__(
|
|
60
|
+
self,
|
|
61
|
+
*,
|
|
62
|
+
stream_function: StreamFunctionDef,
|
|
63
|
+
model_name: str | None = None,
|
|
64
|
+
profile: ModelProfileSpec | None = None,
|
|
65
|
+
) -> None: ...
|
|
56
66
|
|
|
57
67
|
@overload
|
|
58
68
|
def __init__(
|
|
59
|
-
self,
|
|
69
|
+
self,
|
|
70
|
+
function: FunctionDef,
|
|
71
|
+
*,
|
|
72
|
+
stream_function: StreamFunctionDef,
|
|
73
|
+
model_name: str | None = None,
|
|
74
|
+
profile: ModelProfileSpec | None = None,
|
|
60
75
|
) -> None: ...
|
|
61
76
|
|
|
62
77
|
def __init__(
|
|
@@ -65,6 +80,7 @@ class FunctionModel(Model):
|
|
|
65
80
|
*,
|
|
66
81
|
stream_function: StreamFunctionDef | None = None,
|
|
67
82
|
model_name: str | None = None,
|
|
83
|
+
profile: ModelProfileSpec | None = None,
|
|
68
84
|
):
|
|
69
85
|
"""Initialize a `FunctionModel`.
|
|
70
86
|
|
|
@@ -74,6 +90,7 @@ class FunctionModel(Model):
|
|
|
74
90
|
function: The function to call for non-streamed requests.
|
|
75
91
|
stream_function: The function to call for streamed requests.
|
|
76
92
|
model_name: The name of the model. If not provided, a name is generated from the function names.
|
|
93
|
+
profile: The model profile to use.
|
|
77
94
|
"""
|
|
78
95
|
if function is None and stream_function is None:
|
|
79
96
|
raise TypeError('Either `function` or `stream_function` must be provided')
|
|
@@ -83,6 +100,7 @@ class FunctionModel(Model):
|
|
|
83
100
|
function_name = self.function.__name__ if self.function is not None else ''
|
|
84
101
|
stream_function_name = self.stream_function.__name__ if self.stream_function is not None else ''
|
|
85
102
|
self._model_name = model_name or f'function:{function_name}:{stream_function_name}'
|
|
103
|
+
self._profile = profile
|
|
86
104
|
|
|
87
105
|
async def request(
|
|
88
106
|
self,
|