pydantic-ai-slim 0.8.0__py3-none-any.whl → 1.0.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/__init__.py +28 -2
- pydantic_ai/_agent_graph.py +310 -140
- pydantic_ai/_function_schema.py +5 -5
- pydantic_ai/_griffe.py +2 -1
- pydantic_ai/_otel_messages.py +2 -2
- pydantic_ai/_output.py +31 -35
- pydantic_ai/_parts_manager.py +4 -4
- pydantic_ai/_run_context.py +3 -1
- pydantic_ai/_system_prompt.py +2 -2
- pydantic_ai/_tool_manager.py +3 -22
- pydantic_ai/_utils.py +14 -26
- pydantic_ai/ag_ui.py +7 -8
- pydantic_ai/agent/__init__.py +84 -17
- pydantic_ai/agent/abstract.py +35 -4
- pydantic_ai/agent/wrapper.py +6 -0
- pydantic_ai/builtin_tools.py +2 -2
- pydantic_ai/common_tools/duckduckgo.py +4 -2
- pydantic_ai/durable_exec/temporal/__init__.py +70 -17
- pydantic_ai/durable_exec/temporal/_agent.py +23 -2
- pydantic_ai/durable_exec/temporal/_function_toolset.py +53 -6
- pydantic_ai/durable_exec/temporal/_logfire.py +6 -3
- pydantic_ai/durable_exec/temporal/_mcp_server.py +2 -1
- pydantic_ai/durable_exec/temporal/_model.py +2 -2
- pydantic_ai/durable_exec/temporal/_run_context.py +2 -1
- pydantic_ai/durable_exec/temporal/_toolset.py +2 -1
- pydantic_ai/exceptions.py +45 -2
- pydantic_ai/format_prompt.py +2 -2
- pydantic_ai/mcp.py +2 -2
- pydantic_ai/messages.py +81 -28
- pydantic_ai/models/__init__.py +19 -7
- pydantic_ai/models/anthropic.py +6 -6
- pydantic_ai/models/bedrock.py +63 -57
- pydantic_ai/models/cohere.py +3 -3
- pydantic_ai/models/fallback.py +2 -2
- pydantic_ai/models/function.py +25 -23
- pydantic_ai/models/gemini.py +10 -13
- pydantic_ai/models/google.py +4 -4
- pydantic_ai/models/groq.py +5 -5
- pydantic_ai/models/huggingface.py +5 -5
- pydantic_ai/models/instrumented.py +44 -21
- pydantic_ai/models/mcp_sampling.py +3 -1
- pydantic_ai/models/mistral.py +8 -8
- pydantic_ai/models/openai.py +20 -29
- pydantic_ai/models/test.py +24 -4
- pydantic_ai/output.py +27 -32
- pydantic_ai/profiles/__init__.py +3 -3
- pydantic_ai/profiles/groq.py +1 -1
- pydantic_ai/profiles/openai.py +25 -4
- pydantic_ai/providers/anthropic.py +2 -3
- pydantic_ai/providers/bedrock.py +3 -2
- pydantic_ai/result.py +173 -52
- pydantic_ai/retries.py +10 -29
- pydantic_ai/run.py +12 -5
- pydantic_ai/tools.py +126 -22
- pydantic_ai/toolsets/__init__.py +4 -1
- pydantic_ai/toolsets/_dynamic.py +4 -4
- pydantic_ai/toolsets/abstract.py +18 -2
- pydantic_ai/toolsets/approval_required.py +32 -0
- pydantic_ai/toolsets/combined.py +7 -12
- pydantic_ai/toolsets/{deferred.py → external.py} +11 -5
- pydantic_ai/toolsets/filtered.py +1 -1
- pydantic_ai/toolsets/function.py +13 -4
- pydantic_ai/toolsets/wrapper.py +2 -1
- pydantic_ai/usage.py +7 -5
- {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0b1.dist-info}/METADATA +6 -7
- pydantic_ai_slim-1.0.0b1.dist-info/RECORD +120 -0
- pydantic_ai_slim-0.8.0.dist-info/RECORD +0 -119
- {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0b1.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0b1.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0b1.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/models/bedrock.py
CHANGED
|
@@ -8,7 +8,7 @@ from contextlib import asynccontextmanager
|
|
|
8
8
|
from dataclasses import dataclass, field
|
|
9
9
|
from datetime import datetime
|
|
10
10
|
from itertools import count
|
|
11
|
-
from typing import TYPE_CHECKING, Any, Generic, Literal,
|
|
11
|
+
from typing import TYPE_CHECKING, Any, Generic, Literal, cast, overload
|
|
12
12
|
|
|
13
13
|
import anyio
|
|
14
14
|
import anyio.to_thread
|
|
@@ -125,7 +125,7 @@ LatestBedrockModelNames = Literal[
|
|
|
125
125
|
]
|
|
126
126
|
"""Latest Bedrock models."""
|
|
127
127
|
|
|
128
|
-
BedrockModelName =
|
|
128
|
+
BedrockModelName = str | LatestBedrockModelNames
|
|
129
129
|
"""Possible Bedrock model names.
|
|
130
130
|
|
|
131
131
|
Since Bedrock supports a variety of date-stamped models, we explicitly list the latest models but allow any name in the type hints.
|
|
@@ -301,9 +301,13 @@ class BedrockConverseModel(Model):
|
|
|
301
301
|
input_tokens=response['usage']['inputTokens'],
|
|
302
302
|
output_tokens=response['usage']['outputTokens'],
|
|
303
303
|
)
|
|
304
|
-
|
|
304
|
+
response_id = response.get('ResponseMetadata', {}).get('RequestId', None)
|
|
305
305
|
return ModelResponse(
|
|
306
|
-
items,
|
|
306
|
+
parts=items,
|
|
307
|
+
usage=u,
|
|
308
|
+
model_name=self.model_name,
|
|
309
|
+
provider_response_id=response_id,
|
|
310
|
+
provider_name=self._provider.name,
|
|
307
311
|
)
|
|
308
312
|
|
|
309
313
|
@overload
|
|
@@ -486,7 +490,7 @@ class BedrockConverseModel(Model):
|
|
|
486
490
|
else:
|
|
487
491
|
# NOTE: We don't pass the thinking part to Bedrock for models other than Claude since it raises an error.
|
|
488
492
|
pass
|
|
489
|
-
elif isinstance(item,
|
|
493
|
+
elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart):
|
|
490
494
|
pass
|
|
491
495
|
else:
|
|
492
496
|
assert isinstance(item, ToolCallPart)
|
|
@@ -542,7 +546,7 @@ class BedrockConverseModel(Model):
|
|
|
542
546
|
content.append({'video': {'format': format, 'source': {'bytes': item.data}}})
|
|
543
547
|
else:
|
|
544
548
|
raise NotImplementedError('Binary content is not supported yet.')
|
|
545
|
-
elif isinstance(item,
|
|
549
|
+
elif isinstance(item, ImageUrl | DocumentUrl | VideoUrl):
|
|
546
550
|
downloaded_item = await download_item(item, data_format='bytes', type_format='extension')
|
|
547
551
|
format = downloaded_item['data_type']
|
|
548
552
|
if item.kind == 'image-url':
|
|
@@ -606,60 +610,62 @@ class BedrockStreamedResponse(StreamedResponse):
|
|
|
606
610
|
chunk: ConverseStreamOutputTypeDef
|
|
607
611
|
tool_id: str | None = None
|
|
608
612
|
async for chunk in _AsyncIteratorWrapper(self._event_stream):
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
vendor_part_id=index,
|
|
627
|
-
tool_name=tool_name,
|
|
628
|
-
args=None,
|
|
629
|
-
tool_call_id=tool_id,
|
|
630
|
-
)
|
|
631
|
-
if maybe_event: # pragma: no branch
|
|
632
|
-
yield maybe_event
|
|
633
|
-
if 'contentBlockDelta' in chunk:
|
|
634
|
-
index = chunk['contentBlockDelta']['contentBlockIndex']
|
|
635
|
-
delta = chunk['contentBlockDelta']['delta']
|
|
636
|
-
if 'reasoningContent' in delta:
|
|
637
|
-
if text := delta['reasoningContent'].get('text'):
|
|
638
|
-
yield self._parts_manager.handle_thinking_delta(
|
|
613
|
+
match chunk:
|
|
614
|
+
case {'messageStart': _}:
|
|
615
|
+
continue
|
|
616
|
+
case {'messageStop': _}:
|
|
617
|
+
continue
|
|
618
|
+
case {'metadata': metadata}:
|
|
619
|
+
if 'usage' in metadata: # pragma: no branch
|
|
620
|
+
self._usage += self._map_usage(metadata)
|
|
621
|
+
continue
|
|
622
|
+
case {'contentBlockStart': content_block_start}:
|
|
623
|
+
index = content_block_start['contentBlockIndex']
|
|
624
|
+
start = content_block_start['start']
|
|
625
|
+
if 'toolUse' in start: # pragma: no branch
|
|
626
|
+
tool_use_start = start['toolUse']
|
|
627
|
+
tool_id = tool_use_start['toolUseId']
|
|
628
|
+
tool_name = tool_use_start['name']
|
|
629
|
+
maybe_event = self._parts_manager.handle_tool_call_delta(
|
|
639
630
|
vendor_part_id=index,
|
|
640
|
-
|
|
641
|
-
|
|
631
|
+
tool_name=tool_name,
|
|
632
|
+
args=None,
|
|
633
|
+
tool_call_id=tool_id,
|
|
642
634
|
)
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
|
|
635
|
+
if maybe_event: # pragma: no branch
|
|
636
|
+
yield maybe_event
|
|
637
|
+
case {'contentBlockDelta': content_block_delta}:
|
|
638
|
+
index = content_block_delta['contentBlockIndex']
|
|
639
|
+
delta = content_block_delta['delta']
|
|
640
|
+
if 'reasoningContent' in delta:
|
|
641
|
+
if text := delta['reasoningContent'].get('text'):
|
|
642
|
+
yield self._parts_manager.handle_thinking_delta(
|
|
643
|
+
vendor_part_id=index,
|
|
644
|
+
content=text,
|
|
645
|
+
signature=delta['reasoningContent'].get('signature'),
|
|
646
|
+
)
|
|
647
|
+
else: # pragma: no cover
|
|
648
|
+
warnings.warn(
|
|
649
|
+
f'Only text reasoning content is supported yet, but you got {delta["reasoningContent"]}. '
|
|
650
|
+
'Please report this to the maintainers.',
|
|
651
|
+
UserWarning,
|
|
652
|
+
)
|
|
653
|
+
if 'text' in delta:
|
|
654
|
+
maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=index, content=delta['text'])
|
|
655
|
+
if maybe_event is not None: # pragma: no branch
|
|
656
|
+
yield maybe_event
|
|
657
|
+
if 'toolUse' in delta:
|
|
658
|
+
tool_use = delta['toolUse']
|
|
659
|
+
maybe_event = self._parts_manager.handle_tool_call_delta(
|
|
660
|
+
vendor_part_id=index,
|
|
661
|
+
tool_name=tool_use.get('name'),
|
|
662
|
+
args=tool_use.get('input'),
|
|
663
|
+
tool_call_id=tool_id,
|
|
648
664
|
)
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
if 'toolUse' in delta:
|
|
654
|
-
tool_use = delta['toolUse']
|
|
655
|
-
maybe_event = self._parts_manager.handle_tool_call_delta(
|
|
656
|
-
vendor_part_id=index,
|
|
657
|
-
tool_name=tool_use.get('name'),
|
|
658
|
-
args=tool_use.get('input'),
|
|
659
|
-
tool_call_id=tool_id,
|
|
660
|
-
)
|
|
661
|
-
if maybe_event: # pragma: no branch
|
|
662
|
-
yield maybe_event
|
|
665
|
+
if maybe_event: # pragma: no branch
|
|
666
|
+
yield maybe_event
|
|
667
|
+
case _:
|
|
668
|
+
pass # pyright wants match statements to be exhaustive
|
|
663
669
|
|
|
664
670
|
@property
|
|
665
671
|
def model_name(self) -> str:
|
pydantic_ai/models/cohere.py
CHANGED
|
@@ -2,7 +2,7 @@ from __future__ import annotations as _annotations
|
|
|
2
2
|
|
|
3
3
|
from collections.abc import Iterable
|
|
4
4
|
from dataclasses import dataclass, field
|
|
5
|
-
from typing import Literal,
|
|
5
|
+
from typing import Literal, cast
|
|
6
6
|
|
|
7
7
|
from typing_extensions import assert_never
|
|
8
8
|
|
|
@@ -72,7 +72,7 @@ LatestCohereModelNames = Literal[
|
|
|
72
72
|
]
|
|
73
73
|
"""Latest Cohere models."""
|
|
74
74
|
|
|
75
|
-
CohereModelName =
|
|
75
|
+
CohereModelName = str | LatestCohereModelNames
|
|
76
76
|
"""Possible Cohere model names.
|
|
77
77
|
|
|
78
78
|
Since Cohere supports a variety of date-stamped models, we explicitly list the latest models but
|
|
@@ -228,7 +228,7 @@ class CohereModel(Model):
|
|
|
228
228
|
pass
|
|
229
229
|
elif isinstance(item, ToolCallPart):
|
|
230
230
|
tool_calls.append(self._map_tool_call(item))
|
|
231
|
-
elif isinstance(item,
|
|
231
|
+
elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart): # pragma: no cover
|
|
232
232
|
# This is currently never returned from cohere
|
|
233
233
|
pass
|
|
234
234
|
else:
|
pydantic_ai/models/fallback.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
from __future__ import annotations as _annotations
|
|
2
2
|
|
|
3
|
-
from collections.abc import AsyncIterator
|
|
3
|
+
from collections.abc import AsyncIterator, Callable
|
|
4
4
|
from contextlib import AsyncExitStack, asynccontextmanager, suppress
|
|
5
5
|
from dataclasses import dataclass, field
|
|
6
|
-
from typing import TYPE_CHECKING, Any
|
|
6
|
+
from typing import TYPE_CHECKING, Any
|
|
7
7
|
|
|
8
8
|
from opentelemetry.trace import get_current_span
|
|
9
9
|
|
pydantic_ai/models/function.py
CHANGED
|
@@ -2,14 +2,14 @@ from __future__ import annotations as _annotations
|
|
|
2
2
|
|
|
3
3
|
import inspect
|
|
4
4
|
import re
|
|
5
|
-
from collections.abc import AsyncIterator, Awaitable, Iterable, Sequence
|
|
5
|
+
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence
|
|
6
6
|
from contextlib import asynccontextmanager
|
|
7
|
-
from dataclasses import dataclass, field
|
|
7
|
+
from dataclasses import KW_ONLY, dataclass, field
|
|
8
8
|
from datetime import datetime
|
|
9
9
|
from itertools import chain
|
|
10
|
-
from typing import Any,
|
|
10
|
+
from typing import Any, TypeAlias
|
|
11
11
|
|
|
12
|
-
from typing_extensions import
|
|
12
|
+
from typing_extensions import assert_never, overload
|
|
13
13
|
|
|
14
14
|
from .. import _utils, usage
|
|
15
15
|
from .._run_context import RunContext
|
|
@@ -44,8 +44,8 @@ class FunctionModel(Model):
|
|
|
44
44
|
Apart from `__init__`, all methods are private or match those of the base class.
|
|
45
45
|
"""
|
|
46
46
|
|
|
47
|
-
function: FunctionDef | None
|
|
48
|
-
stream_function: StreamFunctionDef | None
|
|
47
|
+
function: FunctionDef | None
|
|
48
|
+
stream_function: StreamFunctionDef | None
|
|
49
49
|
|
|
50
50
|
_model_name: str = field(repr=False)
|
|
51
51
|
_system: str = field(default='function', repr=False)
|
|
@@ -120,10 +120,10 @@ class FunctionModel(Model):
|
|
|
120
120
|
model_request_parameters: ModelRequestParameters,
|
|
121
121
|
) -> ModelResponse:
|
|
122
122
|
agent_info = AgentInfo(
|
|
123
|
-
model_request_parameters.function_tools,
|
|
124
|
-
model_request_parameters.allow_text_output,
|
|
125
|
-
model_request_parameters.output_tools,
|
|
126
|
-
model_settings,
|
|
123
|
+
function_tools=model_request_parameters.function_tools,
|
|
124
|
+
allow_text_output=model_request_parameters.allow_text_output,
|
|
125
|
+
output_tools=model_request_parameters.output_tools,
|
|
126
|
+
model_settings=model_settings,
|
|
127
127
|
)
|
|
128
128
|
|
|
129
129
|
assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests'
|
|
@@ -149,10 +149,10 @@ class FunctionModel(Model):
|
|
|
149
149
|
run_context: RunContext[Any] | None = None,
|
|
150
150
|
) -> AsyncIterator[StreamedResponse]:
|
|
151
151
|
agent_info = AgentInfo(
|
|
152
|
-
model_request_parameters.function_tools,
|
|
153
|
-
model_request_parameters.allow_text_output,
|
|
154
|
-
model_request_parameters.output_tools,
|
|
155
|
-
model_settings,
|
|
152
|
+
function_tools=model_request_parameters.function_tools,
|
|
153
|
+
allow_text_output=model_request_parameters.allow_text_output,
|
|
154
|
+
output_tools=model_request_parameters.output_tools,
|
|
155
|
+
model_settings=model_settings,
|
|
156
156
|
)
|
|
157
157
|
|
|
158
158
|
assert self.stream_function is not None, (
|
|
@@ -182,7 +182,7 @@ class FunctionModel(Model):
|
|
|
182
182
|
return self._system
|
|
183
183
|
|
|
184
184
|
|
|
185
|
-
@dataclass(frozen=True)
|
|
185
|
+
@dataclass(frozen=True, kw_only=True)
|
|
186
186
|
class AgentInfo:
|
|
187
187
|
"""Information about an agent.
|
|
188
188
|
|
|
@@ -212,13 +212,17 @@ class DeltaToolCall:
|
|
|
212
212
|
|
|
213
213
|
name: str | None = None
|
|
214
214
|
"""Incremental change to the name of the tool."""
|
|
215
|
+
|
|
215
216
|
json_args: str | None = None
|
|
216
217
|
"""Incremental change to the arguments as JSON"""
|
|
218
|
+
|
|
219
|
+
_: KW_ONLY
|
|
220
|
+
|
|
217
221
|
tool_call_id: str | None = None
|
|
218
222
|
"""Incremental change to the tool call ID."""
|
|
219
223
|
|
|
220
224
|
|
|
221
|
-
@dataclass
|
|
225
|
+
@dataclass(kw_only=True)
|
|
222
226
|
class DeltaThinkingPart:
|
|
223
227
|
"""Incremental change to a thinking part.
|
|
224
228
|
|
|
@@ -237,18 +241,16 @@ DeltaToolCalls: TypeAlias = dict[int, DeltaToolCall]
|
|
|
237
241
|
DeltaThinkingCalls: TypeAlias = dict[int, DeltaThinkingPart]
|
|
238
242
|
"""A mapping of thinking call IDs to incremental changes."""
|
|
239
243
|
|
|
240
|
-
|
|
241
|
-
FunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], Union[ModelResponse, Awaitable[ModelResponse]]]
|
|
244
|
+
FunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], ModelResponse | Awaitable[ModelResponse]]
|
|
242
245
|
"""A function used to generate a non-streamed response."""
|
|
243
246
|
|
|
244
|
-
# TODO: Change signature as indicated above
|
|
245
247
|
StreamFunctionDef: TypeAlias = Callable[
|
|
246
|
-
[list[ModelMessage], AgentInfo], AsyncIterator[
|
|
248
|
+
[list[ModelMessage], AgentInfo], AsyncIterator[str | DeltaToolCalls | DeltaThinkingCalls]
|
|
247
249
|
]
|
|
248
250
|
"""A function used to generate a streamed response.
|
|
249
251
|
|
|
250
|
-
While this is defined as having return type of `AsyncIterator[
|
|
251
|
-
really be considered as `
|
|
252
|
+
While this is defined as having return type of `AsyncIterator[str | DeltaToolCalls | DeltaThinkingCalls]`, it should
|
|
253
|
+
really be considered as `AsyncIterator[str] | AsyncIterator[DeltaToolCalls] | AsyncIterator[DeltaThinkingCalls]`,
|
|
252
254
|
|
|
253
255
|
E.g. you need to yield all text, all `DeltaToolCalls`, or all `DeltaThinkingCalls`, not mix them.
|
|
254
256
|
"""
|
|
@@ -326,7 +328,7 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.RequestUsage:
|
|
|
326
328
|
for message in messages:
|
|
327
329
|
if isinstance(message, ModelRequest):
|
|
328
330
|
for part in message.parts:
|
|
329
|
-
if isinstance(part,
|
|
331
|
+
if isinstance(part, SystemPromptPart | UserPromptPart):
|
|
330
332
|
request_tokens += _estimate_string_tokens(part.content)
|
|
331
333
|
elif isinstance(part, ToolReturnPart):
|
|
332
334
|
request_tokens += _estimate_string_tokens(part.model_response_str())
|
pydantic_ai/models/gemini.py
CHANGED
|
@@ -5,7 +5,7 @@ from collections.abc import AsyncIterator, Sequence
|
|
|
5
5
|
from contextlib import asynccontextmanager
|
|
6
6
|
from dataclasses import dataclass, field
|
|
7
7
|
from datetime import datetime
|
|
8
|
-
from typing import Annotated, Any, Literal, Protocol,
|
|
8
|
+
from typing import Annotated, Any, Literal, Protocol, cast
|
|
9
9
|
from uuid import uuid4
|
|
10
10
|
|
|
11
11
|
import httpx
|
|
@@ -51,7 +51,7 @@ LatestGeminiModelNames = Literal[
|
|
|
51
51
|
]
|
|
52
52
|
"""Latest Gemini models."""
|
|
53
53
|
|
|
54
|
-
GeminiModelName =
|
|
54
|
+
GeminiModelName = str | LatestGeminiModelNames
|
|
55
55
|
"""Possible Gemini model names.
|
|
56
56
|
|
|
57
57
|
Since Gemini supports a variety of date-stamped models, we explicitly list the latest models but
|
|
@@ -615,7 +615,7 @@ def _content_model_response(m: ModelResponse) -> _GeminiContent:
|
|
|
615
615
|
elif isinstance(item, TextPart):
|
|
616
616
|
if item.content:
|
|
617
617
|
parts.append(_GeminiTextPart(text=item.content))
|
|
618
|
-
elif isinstance(item,
|
|
618
|
+
elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart): # pragma: no cover
|
|
619
619
|
# This is currently never returned from gemini
|
|
620
620
|
pass
|
|
621
621
|
else:
|
|
@@ -690,7 +690,7 @@ def _process_response_from_parts(
|
|
|
690
690
|
f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
|
|
691
691
|
)
|
|
692
692
|
return ModelResponse(
|
|
693
|
-
parts=items, usage=usage, model_name=model_name,
|
|
693
|
+
parts=items, usage=usage, model_name=model_name, provider_response_id=vendor_id, provider_details=vendor_details
|
|
694
694
|
)
|
|
695
695
|
|
|
696
696
|
|
|
@@ -735,16 +735,13 @@ def _part_discriminator(v: Any) -> str:
|
|
|
735
735
|
|
|
736
736
|
# See <https://ai.google.dev/api/caching#Part>
|
|
737
737
|
# we don't currently support other part types
|
|
738
|
-
# TODO discriminator
|
|
739
738
|
_GeminiPartUnion = Annotated[
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
Annotated[_GeminiThoughtPart, pydantic.Tag('thought')],
|
|
747
|
-
],
|
|
739
|
+
Annotated[_GeminiTextPart, pydantic.Tag('text')]
|
|
740
|
+
| Annotated[_GeminiFunctionCallPart, pydantic.Tag('function_call')]
|
|
741
|
+
| Annotated[_GeminiFunctionResponsePart, pydantic.Tag('function_response')]
|
|
742
|
+
| Annotated[_GeminiInlineDataPart, pydantic.Tag('inline_data')]
|
|
743
|
+
| Annotated[_GeminiFileDataPart, pydantic.Tag('file_data')]
|
|
744
|
+
| Annotated[_GeminiThoughtPart, pydantic.Tag('thought')],
|
|
748
745
|
pydantic.Discriminator(_part_discriminator),
|
|
749
746
|
]
|
|
750
747
|
|
pydantic_ai/models/google.py
CHANGED
|
@@ -5,7 +5,7 @@ from collections.abc import AsyncIterator, Awaitable
|
|
|
5
5
|
from contextlib import asynccontextmanager
|
|
6
6
|
from dataclasses import dataclass, field
|
|
7
7
|
from datetime import datetime
|
|
8
|
-
from typing import Any, Literal,
|
|
8
|
+
from typing import Any, Literal, cast, overload
|
|
9
9
|
from uuid import uuid4
|
|
10
10
|
|
|
11
11
|
from typing_extensions import assert_never
|
|
@@ -91,7 +91,7 @@ LatestGoogleModelNames = Literal[
|
|
|
91
91
|
]
|
|
92
92
|
"""Latest Gemini models."""
|
|
93
93
|
|
|
94
|
-
GoogleModelName =
|
|
94
|
+
GoogleModelName = str | LatestGoogleModelNames
|
|
95
95
|
"""Possible Gemini model names.
|
|
96
96
|
|
|
97
97
|
Since Gemini supports a variety of date-stamped models, we explicitly list the latest models but
|
|
@@ -349,7 +349,7 @@ class GoogleModel(Model):
|
|
|
349
349
|
'headers': {'Content-Type': 'application/json', 'User-Agent': get_user_agent()}
|
|
350
350
|
}
|
|
351
351
|
if timeout := model_settings.get('timeout'):
|
|
352
|
-
if isinstance(timeout,
|
|
352
|
+
if isinstance(timeout, int | float):
|
|
353
353
|
http_options['timeout'] = int(1000 * timeout)
|
|
354
354
|
else:
|
|
355
355
|
raise UserError('Google does not support setting ModelSettings.timeout to a httpx.Timeout')
|
|
@@ -648,7 +648,7 @@ def _process_response_from_parts(
|
|
|
648
648
|
parts=items,
|
|
649
649
|
model_name=model_name,
|
|
650
650
|
usage=usage,
|
|
651
|
-
|
|
651
|
+
provider_response_id=vendor_id,
|
|
652
652
|
provider_details=vendor_details,
|
|
653
653
|
provider_name=provider_name,
|
|
654
654
|
)
|
pydantic_ai/models/groq.py
CHANGED
|
@@ -5,7 +5,7 @@ from collections.abc import AsyncIterable, AsyncIterator, Iterable
|
|
|
5
5
|
from contextlib import asynccontextmanager
|
|
6
6
|
from dataclasses import dataclass, field
|
|
7
7
|
from datetime import datetime
|
|
8
|
-
from typing import Any, Literal,
|
|
8
|
+
from typing import Any, Literal, cast, overload
|
|
9
9
|
|
|
10
10
|
from typing_extensions import assert_never
|
|
11
11
|
|
|
@@ -88,7 +88,7 @@ PreviewGroqModelNames = Literal[
|
|
|
88
88
|
]
|
|
89
89
|
"""Preview Groq models from <https://console.groq.com/docs/models#preview-models>."""
|
|
90
90
|
|
|
91
|
-
GroqModelName =
|
|
91
|
+
GroqModelName = str | ProductionGroqModelNames | PreviewGroqModelNames
|
|
92
92
|
"""Possible Groq model names.
|
|
93
93
|
|
|
94
94
|
Since Groq supports a variety of models and the list changes frequencly, we explicitly list the named models as of 2025-03-31
|
|
@@ -285,11 +285,11 @@ class GroqModel(Model):
|
|
|
285
285
|
for c in choice.message.tool_calls:
|
|
286
286
|
items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
|
|
287
287
|
return ModelResponse(
|
|
288
|
-
items,
|
|
288
|
+
parts=items,
|
|
289
289
|
usage=_map_usage(response),
|
|
290
290
|
model_name=response.model,
|
|
291
291
|
timestamp=timestamp,
|
|
292
|
-
|
|
292
|
+
provider_response_id=response.id,
|
|
293
293
|
provider_name=self._provider.name,
|
|
294
294
|
)
|
|
295
295
|
|
|
@@ -347,7 +347,7 @@ class GroqModel(Model):
|
|
|
347
347
|
elif isinstance(item, ThinkingPart):
|
|
348
348
|
# Skip thinking parts when mapping to Groq messages
|
|
349
349
|
continue
|
|
350
|
-
elif isinstance(item,
|
|
350
|
+
elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart): # pragma: no cover
|
|
351
351
|
# This is currently never returned from groq
|
|
352
352
|
pass
|
|
353
353
|
else:
|
pydantic_ai/models/huggingface.py
CHANGED
|
@@ -5,7 +5,7 @@ from collections.abc import AsyncIterable, AsyncIterator
|
|
|
5
5
|
from contextlib import asynccontextmanager
|
|
6
6
|
from dataclasses import dataclass, field
|
|
7
7
|
from datetime import datetime, timezone
|
|
8
|
-
from typing import Any, Literal,
|
|
8
|
+
from typing import Any, Literal, cast, overload
|
|
9
9
|
|
|
10
10
|
from typing_extensions import assert_never
|
|
11
11
|
|
|
@@ -88,7 +88,7 @@ LatestHuggingFaceModelNames = Literal[
|
|
|
88
88
|
"""Latest Hugging Face models."""
|
|
89
89
|
|
|
90
90
|
|
|
91
|
-
HuggingFaceModelName =
|
|
91
|
+
HuggingFaceModelName = str | LatestHuggingFaceModelNames
|
|
92
92
|
"""Possible Hugging Face model names.
|
|
93
93
|
|
|
94
94
|
You can browse available models [here](https://huggingface.co/models?pipeline_tag=text-generation&inference_provider=all&sort=trending).
|
|
@@ -267,11 +267,11 @@ class HuggingFaceModel(Model):
|
|
|
267
267
|
for c in tool_calls:
|
|
268
268
|
items.append(ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id))
|
|
269
269
|
return ModelResponse(
|
|
270
|
-
items,
|
|
270
|
+
parts=items,
|
|
271
271
|
usage=_map_usage(response),
|
|
272
272
|
model_name=response.model,
|
|
273
273
|
timestamp=timestamp,
|
|
274
|
-
|
|
274
|
+
provider_response_id=response.id,
|
|
275
275
|
provider_name=self._provider.name,
|
|
276
276
|
)
|
|
277
277
|
|
|
@@ -320,7 +320,7 @@ class HuggingFaceModel(Model):
|
|
|
320
320
|
# please open an issue. The below code is the code to send thinking to the provider.
|
|
321
321
|
# texts.append(f'<think>\n{item.content}\n</think>')
|
|
322
322
|
pass
|
|
323
|
-
elif isinstance(item,
|
|
323
|
+
elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart): # pragma: no cover
|
|
324
324
|
# This is currently never returned from huggingface
|
|
325
325
|
pass
|
|
326
326
|
else:
|
pydantic_ai/models/instrumented.py
CHANGED
|
@@ -2,10 +2,11 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
import itertools
|
|
4
4
|
import json
|
|
5
|
-
|
|
5
|
+
import warnings
|
|
6
|
+
from collections.abc import AsyncIterator, Callable, Iterator, Mapping
|
|
6
7
|
from contextlib import asynccontextmanager, contextmanager
|
|
7
8
|
from dataclasses import dataclass, field
|
|
8
|
-
from typing import Any,
|
|
9
|
+
from typing import Any, Literal, cast
|
|
9
10
|
from urllib.parse import urlparse
|
|
10
11
|
|
|
11
12
|
from opentelemetry._events import (
|
|
@@ -93,36 +94,41 @@ class InstrumentationSettings:
|
|
|
93
94
|
def __init__(
|
|
94
95
|
self,
|
|
95
96
|
*,
|
|
96
|
-
event_mode: Literal['attributes', 'logs'] = 'attributes',
|
|
97
97
|
tracer_provider: TracerProvider | None = None,
|
|
98
98
|
meter_provider: MeterProvider | None = None,
|
|
99
|
-
event_logger_provider: EventLoggerProvider | None = None,
|
|
100
99
|
include_binary_content: bool = True,
|
|
101
100
|
include_content: bool = True,
|
|
102
|
-
version: Literal[1, 2] =
|
|
101
|
+
version: Literal[1, 2] = 2,
|
|
102
|
+
event_mode: Literal['attributes', 'logs'] = 'attributes',
|
|
103
|
+
event_logger_provider: EventLoggerProvider | None = None,
|
|
103
104
|
):
|
|
104
105
|
"""Create instrumentation options.
|
|
105
106
|
|
|
106
107
|
Args:
|
|
107
|
-
event_mode: The mode for emitting events. If `'attributes'`, events are attached to the span as attributes.
|
|
108
|
-
If `'logs'`, events are emitted as OpenTelemetry log-based events.
|
|
109
108
|
tracer_provider: The OpenTelemetry tracer provider to use.
|
|
110
109
|
If not provided, the global tracer provider is used.
|
|
111
110
|
Calling `logfire.configure()` sets the global tracer provider, so most users don't need this.
|
|
112
111
|
meter_provider: The OpenTelemetry meter provider to use.
|
|
113
112
|
If not provided, the global meter provider is used.
|
|
114
113
|
Calling `logfire.configure()` sets the global meter provider, so most users don't need this.
|
|
115
|
-
event_logger_provider: The OpenTelemetry event logger provider to use.
|
|
116
|
-
If not provided, the global event logger provider is used.
|
|
117
|
-
Calling `logfire.configure()` sets the global event logger provider, so most users don't need this.
|
|
118
|
-
This is only used if `event_mode='logs'`.
|
|
119
114
|
include_binary_content: Whether to include binary content in the instrumentation events.
|
|
120
115
|
include_content: Whether to include prompts, completions, and tool call arguments and responses
|
|
121
116
|
in the instrumentation events.
|
|
122
|
-
version: Version of the data format.
|
|
123
|
-
Version 1 is based on the legacy event-based OpenTelemetry GenAI spec
|
|
124
|
-
|
|
125
|
-
|
|
117
|
+
version: Version of the data format. This is unrelated to the Pydantic AI package version.
|
|
118
|
+
Version 1 is based on the legacy event-based OpenTelemetry GenAI spec
|
|
119
|
+
and will be removed in a future release.
|
|
120
|
+
The parameters `event_mode` and `event_logger_provider` are only relevant for version 1.
|
|
121
|
+
Version 2 uses the newer OpenTelemetry GenAI spec and stores messages in the following attributes:
|
|
122
|
+
- `gen_ai.system_instructions` for instructions passed to the agent.
|
|
123
|
+
- `gen_ai.input.messages` and `gen_ai.output.messages` on model request spans.
|
|
124
|
+
- `pydantic_ai.all_messages` on agent run spans.
|
|
125
|
+
event_mode: The mode for emitting events in version 1.
|
|
126
|
+
If `'attributes'`, events are attached to the span as attributes.
|
|
127
|
+
If `'logs'`, events are emitted as OpenTelemetry log-based events.
|
|
128
|
+
event_logger_provider: The OpenTelemetry event logger provider to use.
|
|
129
|
+
If not provided, the global event logger provider is used.
|
|
130
|
+
Calling `logfire.configure()` sets the global event logger provider, so most users don't need this.
|
|
131
|
+
This is only used if `event_mode='logs'` and `version=1`.
|
|
126
132
|
"""
|
|
127
133
|
from pydantic_ai import __version__
|
|
128
134
|
|
|
@@ -136,6 +142,14 @@ class InstrumentationSettings:
|
|
|
136
142
|
self.event_mode = event_mode
|
|
137
143
|
self.include_binary_content = include_binary_content
|
|
138
144
|
self.include_content = include_content
|
|
145
|
+
|
|
146
|
+
if event_mode == 'logs' and version != 1:
|
|
147
|
+
warnings.warn(
|
|
148
|
+
'event_mode is only relevant for version=1 which is deprecated and will be removed in a future release.',
|
|
149
|
+
stacklevel=2,
|
|
150
|
+
)
|
|
151
|
+
version = 1
|
|
152
|
+
|
|
139
153
|
self.version = version
|
|
140
154
|
|
|
141
155
|
# As specified in the OpenTelemetry GenAI metrics spec:
|
|
@@ -236,27 +250,36 @@ class InstrumentationSettings:
|
|
|
236
250
|
if response.provider_details and 'finish_reason' in response.provider_details:
|
|
237
251
|
output_message['finish_reason'] = response.provider_details['finish_reason']
|
|
238
252
|
instructions = InstrumentedModel._get_instructions(input_messages) # pyright: ignore [reportPrivateUsage]
|
|
253
|
+
system_instructions_attributes = self.system_instructions_attributes(instructions)
|
|
239
254
|
attributes = {
|
|
240
255
|
'gen_ai.input.messages': json.dumps(self.messages_to_otel_messages(input_messages)),
|
|
241
256
|
'gen_ai.output.messages': json.dumps([output_message]),
|
|
257
|
+
**system_instructions_attributes,
|
|
242
258
|
'logfire.json_schema': json.dumps(
|
|
243
259
|
{
|
|
244
260
|
'type': 'object',
|
|
245
261
|
'properties': {
|
|
246
262
|
'gen_ai.input.messages': {'type': 'array'},
|
|
247
263
|
'gen_ai.output.messages': {'type': 'array'},
|
|
248
|
-
**(
|
|
264
|
+
**(
|
|
265
|
+
{'gen_ai.system_instructions': {'type': 'array'}}
|
|
266
|
+
if system_instructions_attributes
|
|
267
|
+
else {}
|
|
268
|
+
),
|
|
249
269
|
'model_request_parameters': {'type': 'object'},
|
|
250
270
|
},
|
|
251
271
|
}
|
|
252
272
|
),
|
|
253
273
|
}
|
|
254
|
-
if instructions is not None:
|
|
255
|
-
attributes['gen_ai.system_instructions'] = json.dumps(
|
|
256
|
-
[_otel_messages.TextPart(type='text', content=instructions)]
|
|
257
|
-
)
|
|
258
274
|
span.set_attributes(attributes)
|
|
259
275
|
|
|
276
|
+
def system_instructions_attributes(self, instructions: str | None) -> dict[str, str]:
|
|
277
|
+
if instructions and self.include_content:
|
|
278
|
+
return {
|
|
279
|
+
'gen_ai.system_instructions': json.dumps([_otel_messages.TextPart(type='text', content=instructions)]),
|
|
280
|
+
}
|
|
281
|
+
return {}
|
|
282
|
+
|
|
260
283
|
def _emit_events(self, span: Span, events: list[Event]) -> None:
|
|
261
284
|
if self.event_mode == 'logs':
|
|
262
285
|
for event in events:
|
|
@@ -357,7 +380,7 @@ class InstrumentedModel(WrapperModel):
|
|
|
357
380
|
|
|
358
381
|
if model_settings:
|
|
359
382
|
for key in MODEL_SETTING_ATTRIBUTES:
|
|
360
|
-
if isinstance(value := model_settings.get(key),
|
|
383
|
+
if isinstance(value := model_settings.get(key), float | int):
|
|
361
384
|
attributes[f'gen_ai.request.{key}'] = value
|
|
362
385
|
|
|
363
386
|
record_metrics: Callable[[], None] | None = None
|
pydantic_ai/models/mcp_sampling.py
CHANGED
|
@@ -2,7 +2,7 @@ from __future__ import annotations as _annotations
|
|
|
2
2
|
|
|
3
3
|
from collections.abc import AsyncIterator
|
|
4
4
|
from contextlib import asynccontextmanager
|
|
5
|
-
from dataclasses import dataclass
|
|
5
|
+
from dataclasses import KW_ONLY, dataclass
|
|
6
6
|
from typing import TYPE_CHECKING, Any, cast
|
|
7
7
|
|
|
8
8
|
from .. import _mcp, exceptions
|
|
@@ -36,6 +36,8 @@ class MCPSamplingModel(Model):
|
|
|
36
36
|
session: ServerSession
|
|
37
37
|
"""The MCP server session to use for sampling."""
|
|
38
38
|
|
|
39
|
+
_: KW_ONLY
|
|
40
|
+
|
|
39
41
|
default_max_tokens: int = 16_384
|
|
40
42
|
"""Default max tokens to use if not set in [`ModelSettings`][pydantic_ai.settings.ModelSettings.max_tokens].
|
|
41
43
|
|