pydantic-ai-slim 1.0.1 → 1.0.2 (py3-none-any.whl)
This diff compares the contents of the two publicly released package versions as they appear in their public registry, and is provided for informational purposes only.
- pydantic_ai/_agent_graph.py +50 -31
- pydantic_ai/_tool_manager.py +4 -0
- pydantic_ai/agent/__init__.py +3 -0
- pydantic_ai/durable_exec/dbos/__init__.py +6 -0
- pydantic_ai/durable_exec/dbos/_agent.py +718 -0
- pydantic_ai/durable_exec/dbos/_mcp_server.py +89 -0
- pydantic_ai/durable_exec/dbos/_model.py +137 -0
- pydantic_ai/durable_exec/dbos/_utils.py +10 -0
- pydantic_ai/mcp.py +1 -1
- pydantic_ai/messages.py +12 -0
- pydantic_ai/models/__init__.py +8 -0
- pydantic_ai/models/anthropic.py +24 -0
- pydantic_ai/models/google.py +43 -4
- pydantic_ai/models/instrumented.py +27 -14
- pydantic_ai/models/openai.py +67 -16
- pydantic_ai/providers/bedrock.py +11 -3
- pydantic_ai/tools.py +11 -0
- pydantic_ai/toolsets/function.py +7 -0
- {pydantic_ai_slim-1.0.1.dist-info → pydantic_ai_slim-1.0.2.dist-info}/METADATA +6 -4
- {pydantic_ai_slim-1.0.1.dist-info → pydantic_ai_slim-1.0.2.dist-info}/RECORD +23 -18
- {pydantic_ai_slim-1.0.1.dist-info → pydantic_ai_slim-1.0.2.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-1.0.1.dist-info → pydantic_ai_slim-1.0.2.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-1.0.1.dist-info → pydantic_ai_slim-1.0.2.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/durable_exec/dbos/_mcp_server.py
ADDED

@@ -0,0 +1,89 @@
+from __future__ import annotations
+
+from abc import ABC
+from collections.abc import Callable
+from typing import Any
+
+from dbos import DBOS
+from typing_extensions import Self
+
+from pydantic_ai.mcp import MCPServer, ToolResult
+from pydantic_ai.tools import AgentDepsT, RunContext
+from pydantic_ai.toolsets.abstract import AbstractToolset, ToolsetTool
+from pydantic_ai.toolsets.wrapper import WrapperToolset
+
+from ._utils import StepConfig
+
+
+class DBOSMCPServer(WrapperToolset[AgentDepsT], ABC):
+    """A wrapper for MCPServer that integrates with DBOS, turning call_tool and get_tools to DBOS steps."""
+
+    def __init__(
+        self,
+        wrapped: MCPServer,
+        *,
+        step_name_prefix: str,
+        step_config: StepConfig,
+    ):
+        super().__init__(wrapped)
+        self._step_config = step_config or {}
+        self._step_name_prefix = step_name_prefix
+        id_suffix = f'__{wrapped.id}' if wrapped.id else ''
+        self._name = f'{step_name_prefix}__mcp_server{id_suffix}'
+
+        # Wrap get_tools in a DBOS step.
+        @DBOS.step(
+            name=f'{self._name}.get_tools',
+            **self._step_config,
+        )
+        async def wrapped_get_tools_step(
+            ctx: RunContext[AgentDepsT],
+        ) -> dict[str, ToolsetTool[AgentDepsT]]:
+            return await super(DBOSMCPServer, self).get_tools(ctx)
+
+        self._dbos_wrapped_get_tools_step = wrapped_get_tools_step
+
+        # Wrap call_tool in a DBOS step.
+        @DBOS.step(
+            name=f'{self._name}.call_tool',
+            **self._step_config,
+        )
+        async def wrapped_call_tool_step(
+            name: str,
+            tool_args: dict[str, Any],
+            ctx: RunContext[AgentDepsT],
+            tool: ToolsetTool[AgentDepsT],
+        ) -> ToolResult:
+            return await super(DBOSMCPServer, self).call_tool(name, tool_args, ctx, tool)
+
+        self._dbos_wrapped_call_tool_step = wrapped_call_tool_step
+
+    @property
+    def id(self) -> str | None:
+        return self.wrapped.id
+
+    async def __aenter__(self) -> Self:
+        # The wrapped MCPServer enters itself around listing and calling tools
+        # so we don't need to enter it here (nor could we because we're not inside a DBOS step).
+        return self
+
+    async def __aexit__(self, *args: Any) -> bool | None:
+        return None
+
+    def visit_and_replace(
+        self, visitor: Callable[[AbstractToolset[AgentDepsT]], AbstractToolset[AgentDepsT]]
+    ) -> AbstractToolset[AgentDepsT]:
+        # DBOS-ified toolsets cannot be swapped out after the fact.
+        return self
+
+    async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
+        return await self._dbos_wrapped_get_tools_step(ctx)
+
+    async def call_tool(
+        self,
+        name: str,
+        tool_args: dict[str, Any],
+        ctx: RunContext[AgentDepsT],
+        tool: ToolsetTool[AgentDepsT],
+    ) -> ToolResult:
+        return await self._dbos_wrapped_call_tool_step(name, tool_args, ctx, tool)
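The pattern above — capturing each public method in a closure decorated with `@DBOS.step` at construction time, so the step name can include per-instance state — can be sketched in isolation. The class below is illustrative only (not part of pydantic-ai or this diff) and assumes a configured DBOS runtime:

```python
from dbos import DBOS


class DurableFetcher:
    """Illustrative only: mirrors how DBOSMCPServer turns methods into DBOS steps."""

    def __init__(self, name: str):
        # Decorating a closure (rather than a method) lets the step name be
        # built from per-instance state, as DBOSMCPServer does with `self._name`.
        @DBOS.step(name=f'{name}.fetch')
        async def fetch_step(url: str) -> str:
            return await self._fetch(url)

        self._fetch_step = fetch_step

    async def _fetch(self, url: str) -> str:
        return f'contents of {url}'  # stand-in for real I/O

    async def fetch(self, url: str) -> str:
        # Inside a DBOS workflow this call is checkpointed; on recovery the
        # recorded result is replayed instead of re-calling the server.
        return await self._fetch_step(url)
```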
pydantic_ai/durable_exec/dbos/_model.py
ADDED

@@ -0,0 +1,137 @@
+from __future__ import annotations
+
+from collections.abc import AsyncIterator
+from contextlib import asynccontextmanager
+from datetime import datetime
+from typing import Any
+
+from dbos import DBOS
+
+from pydantic_ai.agent import EventStreamHandler
+from pydantic_ai.messages import (
+    ModelMessage,
+    ModelResponse,
+    ModelResponseStreamEvent,
+)
+from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse
+from pydantic_ai.models.wrapper import WrapperModel
+from pydantic_ai.settings import ModelSettings
+from pydantic_ai.tools import RunContext
+from pydantic_ai.usage import RequestUsage
+
+from ._utils import StepConfig
+
+
+class DBOSStreamedResponse(StreamedResponse):
+    def __init__(self, model_request_parameters: ModelRequestParameters, response: ModelResponse):
+        super().__init__(model_request_parameters)
+        self.response = response
+
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+        return
+        # noinspection PyUnreachableCode
+        yield
+
+    def get(self) -> ModelResponse:
+        return self.response
+
+    def usage(self) -> RequestUsage:
+        return self.response.usage  # pragma: no cover
+
+    @property
+    def model_name(self) -> str:
+        return self.response.model_name or ''  # pragma: no cover
+
+    @property
+    def provider_name(self) -> str:
+        return self.response.provider_name or ''  # pragma: no cover
+
+    @property
+    def timestamp(self) -> datetime:
+        return self.response.timestamp  # pragma: no cover
+
+
+class DBOSModel(WrapperModel):
+    """A wrapper for Model that integrates with DBOS, turning request and request_stream to DBOS steps."""
+
+    def __init__(
+        self,
+        model: Model,
+        *,
+        step_name_prefix: str,
+        step_config: StepConfig,
+        event_stream_handler: EventStreamHandler[Any] | None = None,
+    ):
+        super().__init__(model)
+        self.step_config = step_config
+        self.event_stream_handler = event_stream_handler
+        self._step_name_prefix = step_name_prefix
+
+        # Wrap the request in a DBOS step.
+        @DBOS.step(
+            name=f'{self._step_name_prefix}__model.request',
+            **self.step_config,
+        )
+        async def wrapped_request_step(
+            messages: list[ModelMessage],
+            model_settings: ModelSettings | None,
+            model_request_parameters: ModelRequestParameters,
+        ) -> ModelResponse:
+            return await super(DBOSModel, self).request(messages, model_settings, model_request_parameters)
+
+        self._dbos_wrapped_request_step = wrapped_request_step
+
+        # Wrap the request_stream in a DBOS step.
+        @DBOS.step(
+            name=f'{self._step_name_prefix}__model.request_stream',
+            **self.step_config,
+        )
+        async def wrapped_request_stream_step(
+            messages: list[ModelMessage],
+            model_settings: ModelSettings | None,
+            model_request_parameters: ModelRequestParameters,
+            run_context: RunContext[Any] | None = None,
+        ) -> ModelResponse:
+            async with super(DBOSModel, self).request_stream(
+                messages, model_settings, model_request_parameters, run_context
+            ) as streamed_response:
+                if self.event_stream_handler is not None:
+                    assert run_context is not None, (
+                        'A DBOS model cannot be used with `pydantic_ai.direct.model_request_stream()` as it requires a `run_context`. Set an `event_stream_handler` on the agent and use `agent.run()` instead.'
+                    )
+                    await self.event_stream_handler(run_context, streamed_response)
+
+                async for _ in streamed_response:
+                    pass
+            return streamed_response.get()
+
+        self._dbos_wrapped_request_stream_step = wrapped_request_stream_step
+
+    async def request(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> ModelResponse:
+        return await self._dbos_wrapped_request_step(messages, model_settings, model_request_parameters)
+
+    @asynccontextmanager
+    async def request_stream(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
+    ) -> AsyncIterator[StreamedResponse]:
+        # If not in a workflow (could be in a step), just call the wrapped request_stream method.
+        if DBOS.workflow_id is None or DBOS.step_id is not None:
+            async with super().request_stream(
+                messages, model_settings, model_request_parameters, run_context
+            ) as streamed_response:
+                yield streamed_response
+            return
+
+        response = await self._dbos_wrapped_request_stream_step(
+            messages, model_settings, model_request_parameters, run_context
+        )
+        yield DBOSStreamedResponse(model_request_parameters, response)
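The `request_stream` path is the subtle part of this file: inside a DBOS workflow the entire stream is consumed within a single step, and callers get back a `DBOSStreamedResponse` that replays the recorded `ModelResponse` without yielding events. Live events are only observable through the agent's `event_stream_handler`, which runs inside the step. A minimal handler, assuming pydantic-ai's usual `EventStreamHandler` signature, might look like:

```python
from collections.abc import AsyncIterable

from pydantic_ai.messages import AgentStreamEvent
from pydantic_ai.tools import RunContext


async def print_events(ctx: RunContext[None], events: AsyncIterable[AgentStreamEvent]) -> None:
    # Runs inside the DBOS step while tokens are still arriving; by the time
    # the workflow sees the DBOSStreamedResponse, the stream is already finished.
    async for event in events:
        print(event)
```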
pydantic_ai/mcp.py
CHANGED

@@ -517,7 +517,7 @@ class MCPServerStdio(MCPServer):
             f'args={self.args!r}',
         ]
         if self.id:
-            repr_args.append(f'id={self.id!r}')
+            repr_args.append(f'id={self.id!r}')
         return f'{self.__class__.__name__}({", ".join(repr_args)})'
pydantic_ai/messages.py
CHANGED

@@ -52,6 +52,15 @@ ImageFormat: TypeAlias = Literal['jpeg', 'png', 'gif', 'webp']
 DocumentFormat: TypeAlias = Literal['csv', 'doc', 'docx', 'html', 'md', 'pdf', 'txt', 'xls', 'xlsx']
 VideoFormat: TypeAlias = Literal['mkv', 'mov', 'mp4', 'webm', 'flv', 'mpeg', 'mpg', 'wmv', 'three_gp']

+FinishReason: TypeAlias = Literal[
+    'stop',
+    'length',
+    'content_filter',
+    'tool_call',
+    'error',
+]
+"""Reason the model finished generating the response, normalized to OpenTelemetry values."""
+

 @dataclass(repr=False)
 class SystemPromptPart:

@@ -1032,6 +1041,9 @@ class ModelResponse:
     ] = None
     """request ID as specified by the model provider. This can be used to track the specific request to the model."""

+    finish_reason: FinishReason | None = None
+    """Reason the model finished generating the response, normalized to OpenTelemetry values."""
+
     @deprecated('`price` is deprecated, use `cost` instead')
     def price(self) -> genai_types.PriceCalculation:  # pragma: no cover
         return self.cost()
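A quick illustration of the new field; the values are exactly those of the `FinishReason` alias above:

```python
from pydantic_ai.messages import ModelResponse, TextPart

# finish_reason defaults to None; providers populate it where they can.
response = ModelResponse(parts=[TextPart(content='hi')], finish_reason='stop')
assert response.finish_reason == 'stop'
```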
pydantic_ai/models/__init__.py
CHANGED

@@ -28,6 +28,7 @@ from ..exceptions import UserError
 from ..messages import (
     FileUrl,
     FinalResultEvent,
+    FinishReason,
     ModelMessage,
     ModelRequest,
     ModelResponse,

@@ -555,6 +556,10 @@ class StreamedResponse(ABC):

     final_result_event: FinalResultEvent | None = field(default=None, init=False)

+    provider_response_id: str | None = field(default=None, init=False)
+    provider_details: dict[str, Any] | None = field(default=None, init=False)
+    finish_reason: FinishReason | None = field(default=None, init=False)
+
     _parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False)
     _event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False)
     _usage: RequestUsage = field(default_factory=RequestUsage, init=False)

@@ -609,6 +614,9 @@ class StreamedResponse(ABC):
             timestamp=self.timestamp,
             usage=self.usage(),
             provider_name=self.provider_name,
+            provider_response_id=self.provider_response_id,
+            provider_details=self.provider_details,
+            finish_reason=self.finish_reason,
         )

     def usage(self) -> RequestUsage:
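These three new `StreamedResponse` fields let streaming subclasses carry provider metadata into the final `ModelResponse` built by `get()`, mirroring what the Anthropic and Google changes below do. A minimal sketch of a subclass that sets them (illustrative, not part of the library; assumes the abstract surface shown in this diff):

```python
import asyncio
from collections.abc import AsyncIterator
from datetime import datetime, timezone

from pydantic_ai.messages import ModelResponseStreamEvent
from pydantic_ai.models import ModelRequestParameters, StreamedResponse


class EmptyStream(StreamedResponse):
    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        # A real implementation would set these while consuming provider events.
        self.provider_response_id = 'resp_123'
        self.finish_reason = 'stop'
        return
        yield  # makes this an (empty) async generator

    @property
    def model_name(self) -> str:
        return 'demo-model'

    @property
    def provider_name(self) -> str:
        return 'demo'

    @property
    def timestamp(self) -> datetime:
        return datetime.now(timezone.utc)


async def main() -> None:
    stream = EmptyStream(ModelRequestParameters())
    async for _ in stream:  # drain the (empty) event stream
        pass
    response = stream.get()  # get() copies the new fields onto the ModelResponse
    assert response.provider_response_id == 'resp_123'
    assert response.finish_reason == 'stop'


asyncio.run(main())
```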
pydantic_ai/models/anthropic.py
CHANGED

@@ -21,6 +21,7 @@ from ..messages import (
     BuiltinToolCallPart,
     BuiltinToolReturnPart,
     DocumentUrl,
+    FinishReason,
     ImageUrl,
     ModelMessage,
     ModelRequest,

@@ -42,6 +43,16 @@ from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent

+_FINISH_REASON_MAP: dict[BetaStopReason, FinishReason] = {
+    'end_turn': 'stop',
+    'max_tokens': 'length',
+    'stop_sequence': 'stop',
+    'tool_use': 'tool_call',
+    'pause_turn': 'stop',
+    'refusal': 'content_filter',
+}
+
+
 try:
     from anthropic import NOT_GIVEN, APIStatusError, AsyncStream
     from anthropic.types.beta import (

@@ -70,6 +81,7 @@ try:
         BetaServerToolUseBlock,
         BetaServerToolUseBlockParam,
         BetaSignatureDelta,
+        BetaStopReason,
         BetaTextBlock,
         BetaTextBlockParam,
         BetaTextDelta,

@@ -326,12 +338,20 @@ class AnthropicModel(Model):
             )
         )

+        finish_reason: FinishReason | None = None
+        provider_details: dict[str, Any] | None = None
+        if raw_finish_reason := response.stop_reason:  # pragma: no branch
+            provider_details = {'finish_reason': raw_finish_reason}
+            finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
+
         return ModelResponse(
             parts=items,
             usage=_map_usage(response),
             model_name=response.model,
             provider_response_id=response.id,
             provider_name=self._provider.name,
+            finish_reason=finish_reason,
+            provider_details=provider_details,
         )

     async def _process_streamed_response(

@@ -583,6 +603,7 @@ class AnthropicStreamedResponse(StreamedResponse):
         async for event in self._response:
             if isinstance(event, BetaRawMessageStartEvent):
                 self._usage = _map_usage(event)
+                self.provider_response_id = event.message.id

             elif isinstance(event, BetaRawContentBlockStartEvent):
                 current_block = event.content_block

@@ -646,6 +667,9 @@ class AnthropicStreamedResponse(StreamedResponse):

             elif isinstance(event, BetaRawMessageDeltaEvent):
                 self._usage = _map_usage(event)
+                if raw_finish_reason := event.delta.stop_reason:  # pragma: no branch
+                    self.provider_details = {'finish_reason': raw_finish_reason}
+                    self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)

             elif isinstance(event, BetaRawContentBlockStopEvent | BetaRawMessageStopEvent):  # pragma: no branch
                 current_block = None
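End to end, an Anthropic response now exposes both the normalized and the raw stop reason. A hedged sketch of reading them from a run (the model id and prompt are placeholders; requires Anthropic credentials):

```python
from pydantic_ai import Agent

agent = Agent('anthropic:claude-3-5-sonnet-latest')  # placeholder model id
result = agent.run_sync('Say hi')

response = result.all_messages()[-1]  # the final ModelResponse
print(response.finish_reason)     # normalized, e.g. 'stop'
print(response.provider_details)  # raw, e.g. {'finish_reason': 'end_turn'}
```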
pydantic_ai/models/google.py
CHANGED

@@ -20,6 +20,7 @@ from ..messages import (
     BuiltinToolCallPart,
     BuiltinToolReturnPart,
     FileUrl,
+    FinishReason,
     ModelMessage,
     ModelRequest,
     ModelResponse,

@@ -54,6 +55,7 @@ try:
         ContentUnionDict,
         CountTokensConfigDict,
         ExecutableCodeDict,
+        FinishReason as GoogleFinishReason,
         FunctionCallDict,
         FunctionCallingConfigDict,
         FunctionCallingConfigMode,

@@ -99,6 +101,22 @@ allow any name in the type hints.
 See [the Gemini API docs](https://ai.google.dev/gemini-api/docs/models/gemini#model-variations) for a full list.
 """

+_FINISH_REASON_MAP: dict[GoogleFinishReason, FinishReason | None] = {
+    GoogleFinishReason.FINISH_REASON_UNSPECIFIED: None,
+    GoogleFinishReason.STOP: 'stop',
+    GoogleFinishReason.MAX_TOKENS: 'length',
+    GoogleFinishReason.SAFETY: 'content_filter',
+    GoogleFinishReason.RECITATION: 'content_filter',
+    GoogleFinishReason.LANGUAGE: 'error',
+    GoogleFinishReason.OTHER: None,
+    GoogleFinishReason.BLOCKLIST: 'content_filter',
+    GoogleFinishReason.PROHIBITED_CONTENT: 'content_filter',
+    GoogleFinishReason.SPII: 'content_filter',
+    GoogleFinishReason.MALFORMED_FUNCTION_CALL: 'error',
+    GoogleFinishReason.IMAGE_SAFETY: 'content_filter',
+    GoogleFinishReason.UNEXPECTED_TOOL_CALL: 'error',
+}
+

 class GoogleModelSettings(ModelSettings, total=False):
     """Settings used for a Gemini model request."""

@@ -129,6 +147,12 @@ class GoogleModelSettings(ModelSettings, total=False):
     See <https://ai.google.dev/api/generate-content#MediaResolution> for more information.
     """

+    google_cached_content: str
+    """The name of the cached content to use for the model.
+
+    See <https://ai.google.dev/gemini-api/docs/caching> for more information.
+    """
+

 @dataclass(init=False)
 class GoogleModel(Model):

@@ -377,6 +401,7 @@ class GoogleModel(Model):
             thinking_config=model_settings.get('google_thinking_config'),
             labels=model_settings.get('google_labels'),
             media_resolution=model_settings.get('google_video_resolution'),
+            cached_content=model_settings.get('google_cached_content'),
             tools=cast(ToolListUnionDict, tools),
             tool_config=tool_config,
             response_mime_type=response_mime_type,

@@ -396,11 +421,14 @@
                 'Content field missing from Gemini response', str(response)
             )  # pragma: no cover
         parts = candidate.content.parts or []
-
+
+        vendor_id = response.response_id
         vendor_details: dict[str, Any] | None = None
-        finish_reason = candidate.finish_reason
-        if finish_reason:  # pragma: no branch
-            vendor_details = {'finish_reason': finish_reason.value}
+        finish_reason: FinishReason | None = None
+        if raw_finish_reason := candidate.finish_reason:  # pragma: no branch
+            vendor_details = {'finish_reason': raw_finish_reason.value}
+            finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
+
         usage = _metadata_as_usage(response)
         return _process_response_from_parts(
             parts,

@@ -409,6 +437,7 @@
             usage,
             vendor_id=vendor_id,
             vendor_details=vendor_details,
+            finish_reason=finish_reason,
         )

     async def _process_streamed_response(

@@ -543,6 +572,14 @@ class GeminiStreamedResponse(StreamedResponse):

             assert chunk.candidates is not None
             candidate = chunk.candidates[0]
+
+            if chunk.response_id:  # pragma: no branch
+                self.provider_response_id = chunk.response_id
+
+            if raw_finish_reason := candidate.finish_reason:
+                self.provider_details = {'finish_reason': raw_finish_reason.value}
+                self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
+
             if candidate.content is None or candidate.content.parts is None:
                 if candidate.finish_reason == 'STOP':  # pragma: no cover
                     # Normal completion - skip this chunk

@@ -625,6 +662,7 @@ def _process_response_from_parts(
     usage: usage.RequestUsage,
     vendor_id: str | None,
     vendor_details: dict[str, Any] | None = None,
+    finish_reason: FinishReason | None = None,
 ) -> ModelResponse:
     items: list[ModelResponsePart] = []
     for part in parts:

@@ -665,6 +703,7 @@
         provider_response_id=vendor_id,
         provider_details=vendor_details,
         provider_name=provider_name,
+        finish_reason=finish_reason,
     )

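Besides the finish-reason plumbing, this file adds the `google_cached_content` setting. A hedged usage sketch (the model id and cache name are placeholders; the cached-content resource must already exist on the Gemini API side):

```python
from pydantic_ai import Agent
from pydantic_ai.models.google import GoogleModel, GoogleModelSettings

model = GoogleModel('gemini-2.5-flash')  # placeholder model id
# Name of a previously created cached-content resource.
settings = GoogleModelSettings(google_cached_content='cachedContents/example-id')
agent = Agent(model, model_settings=settings)
```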
pydantic_ai/models/instrumented.py
CHANGED

@@ -221,7 +221,10 @@
                 _otel_messages.ChatMessage(role='system' if is_system else 'user', parts=message_parts)
             )
         elif isinstance(message, ModelResponse):  # pragma: no branch
-            result.append(_otel_messages.OutputMessage(role='assistant', parts=message.otel_message_parts(self)))
+            otel_message = _otel_messages.OutputMessage(role='assistant', parts=message.otel_message_parts(self))
+            if message.finish_reason is not None:
+                otel_message['finish_reason'] = message.finish_reason
+            result.append(otel_message)
         return result

 def handle_messages(self, input_messages: list[ModelMessage], response: ModelResponse, system: str, span: Span):

@@ -246,12 +249,10 @@
         else:
             output_messages = self.messages_to_otel_messages([response])
             assert len(output_messages) == 1
-            output_message = output_messages[0]
-            if response.provider_details and 'finish_reason' in response.provider_details:
-                output_message['finish_reason'] = response.provider_details['finish_reason']
+            output_message = output_messages[0]
         instructions = InstrumentedModel._get_instructions(input_messages)  # pyright: ignore [reportPrivateUsage]
         system_instructions_attributes = self.system_instructions_attributes(instructions)
-        attributes = {
+        attributes: dict[str, AttributeValue] = {
             'gen_ai.input.messages': json.dumps(self.messages_to_otel_messages(input_messages)),
             'gen_ai.output.messages': json.dumps([output_message]),
             **system_instructions_attributes,

@@ -420,17 +421,25 @@
             return

         self.instrumentation_settings.handle_messages(messages, response, system, span)
+
+        attributes_to_set = {
+            **response.usage.opentelemetry_attributes(),
+            'gen_ai.response.model': response_model,
+        }
         try:
-
+            attributes_to_set['operation.cost'] = float(response.cost().total_price)
         except LookupError:
-
-
-
-
-
-
-
-
+            # The cost of this provider/model is unknown, which is common.
+            pass
+        except Exception as e:
+            warnings.warn(
+                f'Failed to get cost from response: {type(e).__name__}: {e}', CostCalculationFailedWarning
+            )
+        if response.provider_response_id is not None:
+            attributes_to_set['gen_ai.response.id'] = response.provider_response_id
+        if response.finish_reason is not None:
+            attributes_to_set['gen_ai.response.finish_reasons'] = [response.finish_reason]
+        span.set_attributes(attributes_to_set)
         span.update_name(f'{operation} {request_model}')

         yield finish

@@ -478,3 +487,7 @@
             return str(value)
         except Exception as e:
             return f'Unable to serialize: {e}'
+
+
+class CostCalculationFailedWarning(Warning):
+    """Warning raised when cost calculation fails."""