pydantic-ai-slim 1.9.0__py3-none-any.whl → 1.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/_agent_graph.py +18 -14
- pydantic_ai/_output.py +20 -105
- pydantic_ai/_run_context.py +8 -2
- pydantic_ai/_tool_manager.py +30 -11
- pydantic_ai/_utils.py +18 -0
- pydantic_ai/agent/__init__.py +34 -32
- pydantic_ai/agent/abstract.py +155 -3
- pydantic_ai/agent/wrapper.py +5 -0
- pydantic_ai/common_tools/duckduckgo.py +1 -1
- pydantic_ai/durable_exec/dbos/_agent.py +28 -0
- pydantic_ai/durable_exec/prefect/_agent.py +25 -0
- pydantic_ai/durable_exec/temporal/_agent.py +25 -0
- pydantic_ai/durable_exec/temporal/_function_toolset.py +23 -73
- pydantic_ai/durable_exec/temporal/_mcp_server.py +30 -30
- pydantic_ai/durable_exec/temporal/_run_context.py +9 -3
- pydantic_ai/durable_exec/temporal/_toolset.py +67 -3
- pydantic_ai/mcp.py +4 -4
- pydantic_ai/messages.py +11 -2
- pydantic_ai/models/__init__.py +80 -35
- pydantic_ai/models/anthropic.py +27 -8
- pydantic_ai/models/bedrock.py +3 -3
- pydantic_ai/models/cohere.py +5 -3
- pydantic_ai/models/fallback.py +25 -4
- pydantic_ai/models/function.py +8 -0
- pydantic_ai/models/gemini.py +3 -3
- pydantic_ai/models/google.py +25 -22
- pydantic_ai/models/groq.py +5 -3
- pydantic_ai/models/huggingface.py +3 -3
- pydantic_ai/models/instrumented.py +29 -13
- pydantic_ai/models/mistral.py +6 -4
- pydantic_ai/models/openai.py +15 -6
- pydantic_ai/models/outlines.py +21 -12
- pydantic_ai/models/wrapper.py +1 -1
- pydantic_ai/output.py +3 -2
- pydantic_ai/profiles/openai.py +5 -2
- pydantic_ai/providers/anthropic.py +2 -2
- pydantic_ai/providers/openrouter.py +3 -0
- pydantic_ai/result.py +159 -4
- pydantic_ai/tools.py +12 -10
- pydantic_ai/ui/_adapter.py +2 -2
- pydantic_ai/ui/_event_stream.py +4 -4
- pydantic_ai/ui/ag_ui/_event_stream.py +11 -2
- pydantic_ai/ui/ag_ui/app.py +8 -1
- {pydantic_ai_slim-1.9.0.dist-info → pydantic_ai_slim-1.12.0.dist-info}/METADATA +9 -7
- {pydantic_ai_slim-1.9.0.dist-info → pydantic_ai_slim-1.12.0.dist-info}/RECORD +48 -48
- {pydantic_ai_slim-1.9.0.dist-info → pydantic_ai_slim-1.12.0.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-1.9.0.dist-info → pydantic_ai_slim-1.12.0.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-1.9.0.dist-info → pydantic_ai_slim-1.12.0.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/models/anthropic.py
CHANGED
@@ -39,7 +39,7 @@ from ..messages import (
 from ..profiles import ModelProfileSpec
 from ..providers import Provider, infer_provider
 from ..providers.anthropic import AsyncAnthropicClient
-from ..settings import ModelSettings
+from ..settings import ModelSettings, merge_model_settings
 from ..tools import ToolDefinition
 from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent

@@ -240,6 +240,27 @@ class AnthropicModel(Model):
         async with response:
             yield await self._process_streamed_response(response, model_request_parameters)

+    def prepare_request(
+        self, model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters
+    ) -> tuple[ModelSettings | None, ModelRequestParameters]:
+        settings = merge_model_settings(self.settings, model_settings)
+        if (
+            model_request_parameters.output_tools
+            and settings
+            and (thinking := settings.get('anthropic_thinking'))
+            and thinking.get('type') == 'enabled'
+        ):
+            if model_request_parameters.output_mode == 'auto':
+                model_request_parameters = replace(model_request_parameters, output_mode='prompted')
+            elif (
+                model_request_parameters.output_mode == 'tool' and not model_request_parameters.allow_text_output
+            ):  # pragma: no branch
+                # This would result in `tool_choice=required`, which Anthropic does not support with thinking.
+                raise UserError(
+                    'Anthropic does not support thinking and output tools at the same time. Use `output_type=PromptedOutput(...)` instead.'
+                )
+        return super().prepare_request(model_settings, model_request_parameters)
+
     @overload
     async def _messages_create(
         self,
@@ -278,17 +299,13 @@ class AnthropicModel(Model):
         else:
             if not model_request_parameters.allow_text_output:
                 tool_choice = {'type': 'any'}
-                if (thinking := model_settings.get('anthropic_thinking')) and thinking.get('type') == 'enabled':
-                    raise UserError(
-                        'Anthropic does not support thinking and output tools at the same time. Use `output_type=PromptedOutput(...)` instead.'
-                    )
             else:
                 tool_choice = {'type': 'auto'}

             if (allow_parallel_tool_calls := model_settings.get('parallel_tool_calls')) is not None:
                 tool_choice['disable_parallel_tool_use'] = not allow_parallel_tool_calls

-        system_prompt, anthropic_messages = await self._map_message(messages)
+        system_prompt, anthropic_messages = await self._map_message(messages, model_request_parameters)

         try:
             extra_headers = model_settings.get('extra_headers', {})
@@ -446,7 +463,9 @@ class AnthropicModel(Model):
         )
         return tools, mcp_servers, beta_features

-    async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[BetaMessageParam]]:
+    async def _map_message(  # noqa: C901
+        self, messages: list[ModelMessage], model_request_parameters: ModelRequestParameters
+    ) -> tuple[str, list[BetaMessageParam]]:
         """Just maps a `pydantic_ai.Message` to a `anthropic.types.MessageParam`."""
         system_prompt_parts: list[str] = []
         anthropic_messages: list[BetaMessageParam] = []
@@ -615,7 +634,7 @@ class AnthropicModel(Model):
             anthropic_messages.append(BetaMessageParam(role='assistant', content=assistant_content_params))
         else:
            assert_never(m)
-        if instructions := self._get_instructions(messages):
+        if instructions := self._get_instructions(messages, model_request_parameters):
             system_prompt_parts.insert(0, instructions)
         system_prompt = '\n\n'.join(system_prompt_parts)
         return system_prompt, anthropic_messages
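The practical effect of the new `prepare_request` hook: enabling Anthropic's extended thinking together with a structured `output_type` no longer raises under the default `'auto'` output mode — the request is downgraded to prompted output, and the `UserError` is reserved for an explicit tool-output mode that forbids text output. A minimal usage sketch; the model name and thinking budget are illustrative, not taken from this diff:

from pydantic import BaseModel

from pydantic_ai import Agent
from pydantic_ai.models.anthropic import AnthropicModel, AnthropicModelSettings

class CityLocation(BaseModel):
    city: str
    country: str

# With thinking enabled, 1.9.0 raised UserError for any output tool; as of this
# change the default 'auto' output mode is rewritten to 'prompted' instead.
agent = Agent(
    AnthropicModel('claude-sonnet-4-0'),  # illustrative model name
    output_type=CityLocation,
    model_settings=AnthropicModelSettings(
        anthropic_thinking={'type': 'enabled', 'budget_tokens': 1024},
    ),
)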
pydantic_ai/models/bedrock.py
CHANGED
@@ -374,7 +374,7 @@ class BedrockConverseModel(Model):
         model_settings: BedrockModelSettings | None,
         model_request_parameters: ModelRequestParameters,
     ) -> ConverseResponseTypeDef | ConverseStreamResponseTypeDef:
-        system_prompt, bedrock_messages = await self._map_messages(messages)
+        system_prompt, bedrock_messages = await self._map_messages(messages, model_request_parameters)
         inference_config = self._map_inference_config(model_settings)

         params: ConverseRequestTypeDef = {
@@ -450,7 +450,7 @@ class BedrockConverseModel(Model):
         return tool_config

     async def _map_messages(  # noqa: C901
-        self, messages: list[ModelMessage]
+        self, messages: list[ModelMessage], model_request_parameters: ModelRequestParameters
     ) -> tuple[list[SystemContentBlockTypeDef], list[MessageUnionTypeDef]]:
         """Maps a `pydantic_ai.Message` to the Bedrock `MessageUnionTypeDef`.

@@ -561,7 +561,7 @@ class BedrockConverseModel(Model):
             processed_messages.append(current_message)
             last_message = cast(dict[str, Any], current_message)

-        if instructions := self._get_instructions(messages):
+        if instructions := self._get_instructions(messages, model_request_parameters):
             system_prompt.insert(0, {'text': instructions})

         return system_prompt, processed_messages
pydantic_ai/models/cohere.py
CHANGED
@@ -178,7 +178,7 @@ class CohereModel(Model):
         if model_request_parameters.builtin_tools:
             raise UserError('Cohere does not support built-in tools')

-        cohere_messages = self._map_messages(messages)
+        cohere_messages = self._map_messages(messages, model_request_parameters)
         try:
             return await self.client.chat(
                 model=self._model_name,
@@ -229,7 +229,9 @@ class CohereModel(Model):
             provider_details=provider_details,
         )

-    def _map_messages(self, messages: list[ModelMessage]) -> list[ChatMessageV2]:
+    def _map_messages(
+        self, messages: list[ModelMessage], model_request_parameters: ModelRequestParameters
+    ) -> list[ChatMessageV2]:
         """Just maps a `pydantic_ai.Message` to a `cohere.ChatMessageV2`."""
         cohere_messages: list[ChatMessageV2] = []
         for message in messages:
@@ -268,7 +270,7 @@ class CohereModel(Model):
                 cohere_messages.append(message_param)
             else:
                 assert_never(message)
-        if instructions := self._get_instructions(messages):
+        if instructions := self._get_instructions(messages, model_request_parameters):
             cohere_messages.insert(0, SystemChatMessageV2(role='system', content=instructions))
         return cohere_messages

pydantic_ai/models/fallback.py
CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations as _annotations
 from collections.abc import AsyncIterator, Callable
 from contextlib import AsyncExitStack, asynccontextmanager, suppress
 from dataclasses import dataclass, field
+from functools import cached_property
 from typing import TYPE_CHECKING, Any

 from opentelemetry.trace import get_current_span
@@ -11,6 +12,7 @@ from pydantic_ai._run_context import RunContext
 from pydantic_ai.models.instrumented import InstrumentedModel

 from ..exceptions import FallbackExceptionGroup, ModelHTTPError
+from ..profiles import ModelProfile
 from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model

 if TYPE_CHECKING:
@@ -78,6 +80,7 @@ class FallbackModel(Model):

         for model in self.models:
             try:
+                _, prepared_parameters = model.prepare_request(model_settings, model_request_parameters)
                 response = await model.request(messages, model_settings, model_request_parameters)
             except Exception as exc:
                 if self._fallback_on(exc):
@@ -85,7 +88,7 @@ class FallbackModel(Model):
                     continue
                 raise exc

-            self._set_span_attributes(model)
+            self._set_span_attributes(model, prepared_parameters)
             return response

         raise FallbackExceptionGroup('All models from FallbackModel failed', exceptions)
@@ -104,6 +107,7 @@ class FallbackModel(Model):
         for model in self.models:
             async with AsyncExitStack() as stack:
                 try:
+                    _, prepared_parameters = model.prepare_request(model_settings, model_request_parameters)
                     response = await stack.enter_async_context(
                         model.request_stream(messages, model_settings, model_request_parameters, run_context)
                     )
@@ -113,19 +117,36 @@ class FallbackModel(Model):
                     continue
                 raise exc  # pragma: no cover

-            self._set_span_attributes(model)
+            self._set_span_attributes(model, prepared_parameters)
             yield response
             return

         raise FallbackExceptionGroup('All models from FallbackModel failed', exceptions)

-    def _set_span_attributes(self, model: Model):
+    @cached_property
+    def profile(self) -> ModelProfile:
+        raise NotImplementedError('FallbackModel does not have its own model profile.')
+
+    def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
+        return model_request_parameters  # pragma: no cover
+
+    def prepare_request(
+        self, model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters
+    ) -> tuple[ModelSettings | None, ModelRequestParameters]:
+        return model_settings, model_request_parameters
+
+    def _set_span_attributes(self, model: Model, model_request_parameters: ModelRequestParameters):
         with suppress(Exception):
             span = get_current_span()
             if span.is_recording():
                 attributes = getattr(span, 'attributes', {})
                 if attributes.get('gen_ai.request.model') == self.model_name:  # pragma: no branch
-                    span.set_attributes(InstrumentedModel.model_attributes(model))
+                    span.set_attributes(
+                        {
+                            **InstrumentedModel.model_attributes(model),
+                            **InstrumentedModel.model_request_parameters_attributes(model_request_parameters),
+                        }
+                    )


 def _default_fallback_condition_factory(exceptions: tuple[type[Exception], ...]) -> Callable[[Exception], bool]:
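`FallbackModel` now implements `prepare_request` as a pass-through and has each candidate model prepare its own request, so the `model_request_parameters` span attribute reflects whichever model actually served the request (via the new `InstrumentedModel.model_request_parameters_attributes` helper). A hedged construction sketch; the model names are illustrative:

from pydantic_ai import Agent
from pydantic_ai.models.anthropic import AnthropicModel
from pydantic_ai.models.fallback import FallbackModel
from pydantic_ai.models.openai import OpenAIChatModel

# Each wrapped model runs its own prepare_request(); the parameters prepared by
# the model that succeeds are what _set_span_attributes records on the span.
model = FallbackModel(
    OpenAIChatModel('gpt-4o'),
    AnthropicModel('claude-sonnet-4-0'),
)
agent = Agent(model)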
pydantic_ai/models/function.py
CHANGED
@@ -135,6 +135,8 @@ class FunctionModel(Model):
             allow_text_output=model_request_parameters.allow_text_output,
             output_tools=model_request_parameters.output_tools,
             model_settings=model_settings,
+            model_request_parameters=model_request_parameters,
+            instructions=self._get_instructions(messages, model_request_parameters),
         )

         assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests'
@@ -168,6 +170,8 @@ class FunctionModel(Model):
             allow_text_output=model_request_parameters.allow_text_output,
             output_tools=model_request_parameters.output_tools,
             model_settings=model_settings,
+            model_request_parameters=model_request_parameters,
+            instructions=self._get_instructions(messages, model_request_parameters),
         )

         assert self.stream_function is not None, (
@@ -216,6 +220,10 @@ class AgentInfo:
     """The tools that can called to produce the final output of the run."""
     model_settings: ModelSettings | None
     """The model settings passed to the run call."""
+    model_request_parameters: ModelRequestParameters
+    """The model request parameters passed to the run call."""
+    instructions: str | None
+    """The instructions passed to model."""


 @dataclass
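`AgentInfo` now carries the prepared `model_request_parameters` and the resolved `instructions`, so `FunctionModel`-based test doubles can observe them. A small sketch of the documented `FunctionModel` pattern using the new fields:

from pydantic_ai import Agent
from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
from pydantic_ai.models.function import AgentInfo, FunctionModel

def model_function(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
    # The two AgentInfo fields added in this diff:
    assert info.model_request_parameters is not None
    return ModelResponse(parts=[TextPart(info.instructions or 'no instructions')])

agent = Agent(FunctionModel(model_function), instructions='Reply concisely.')
result = agent.run_sync('hi')
print(result.output)  # -> 'Reply concisely.'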
pydantic_ai/models/gemini.py
CHANGED
@@ -218,7 +218,7 @@ class GeminiModel(Model):
     ) -> AsyncIterator[HTTPResponse]:
         tools = self._get_tools(model_request_parameters)
         tool_config = self._get_tool_config(model_request_parameters, tools)
-        sys_prompt_parts, contents = await self._message_to_gemini_content(messages)
+        sys_prompt_parts, contents = await self._message_to_gemini_content(messages, model_request_parameters)

         request_data = _GeminiRequest(contents=contents)
         if sys_prompt_parts:
@@ -331,7 +331,7 @@ class GeminiModel(Model):
         )

     async def _message_to_gemini_content(
-        self, messages: list[ModelMessage]
+        self, messages: list[ModelMessage], model_request_parameters: ModelRequestParameters
     ) -> tuple[list[_GeminiTextPart], list[_GeminiContent]]:
         sys_prompt_parts: list[_GeminiTextPart] = []
         contents: list[_GeminiContent] = []
@@ -361,7 +361,7 @@ class GeminiModel(Model):
             contents.append(_content_model_response(m))
         else:
             assert_never(m)
-        if instructions := self._get_instructions(messages):
+        if instructions := self._get_instructions(messages, model_request_parameters):
             sys_prompt_parts.insert(0, _GeminiTextPart(text=instructions))
         return sys_prompt_parts, contents

pydantic_ai/models/google.py
CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations as _annotations
 import base64
 from collections.abc import AsyncIterator, Awaitable
 from contextlib import asynccontextmanager
-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, replace
 from datetime import datetime
 from typing import Any, Literal, cast, overload
 from uuid import uuid4
@@ -224,6 +224,18 @@ class GoogleModel(Model):
         """The model provider."""
         return self._provider.name

+    def prepare_request(
+        self, model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters
+    ) -> tuple[ModelSettings | None, ModelRequestParameters]:
+        if model_request_parameters.builtin_tools and model_request_parameters.output_tools:
+            if model_request_parameters.output_mode == 'auto':
+                model_request_parameters = replace(model_request_parameters, output_mode='prompted')
+            else:
+                raise UserError(
+                    'Google does not support output tools and built-in tools at the same time. Use `output_type=PromptedOutput(...)` instead.'
+                )
+        return super().prepare_request(model_settings, model_request_parameters)
+
     async def request(
         self,
         messages: list[ModelMessage],
@@ -320,12 +332,8 @@ class GoogleModel(Model):
         ]

         if model_request_parameters.builtin_tools:
-            if model_request_parameters.output_tools:
-                raise UserError(
-                    'Gemini does not support output tools and built-in tools at the same time. Use `output_type=PromptedOutput(...)` instead.'
-                )
             if model_request_parameters.function_tools:
-                raise UserError('Gemini does not support function tools and built-in tools at the same time.')
+                raise UserError('Google does not support function tools and built-in tools at the same time.')

             for tool in model_request_parameters.builtin_tools:
                 if isinstance(tool, WebSearchTool):
@@ -402,7 +410,7 @@ class GoogleModel(Model):
         if model_request_parameters.output_mode == 'native':
             if tools:
                 raise UserError(
-                    'Gemini does not support `NativeOutput` and tools at the same time. Use `output_type=ToolOutput(...)` instead.'
+                    'Google does not support `NativeOutput` and tools at the same time. Use `output_type=ToolOutput(...)` instead.'
                 )
             response_mime_type = 'application/json'
             output_object = model_request_parameters.output_object
@@ -414,7 +422,7 @@ class GoogleModel(Model):
             response_mime_type = 'application/json'

         tool_config = self._get_tool_config(model_request_parameters, tools)
-        system_instruction, contents = await self._map_messages(messages)
+        system_instruction, contents = await self._map_messages(messages, model_request_parameters)

         modalities = [Modality.TEXT.value]
         if self.profile.supports_image_output:
@@ -471,11 +479,9 @@ class GoogleModel(Model):
                 raise UnexpectedModelBehavior(
                     f'Content filter {raw_finish_reason.value!r} triggered', response.model_dump_json()
                 )
-            else:
-                raise UnexpectedModelBehavior(
-                    'Content field missing from Gemini response', response.model_dump_json()
-                )  # pragma: no cover
-        parts = candidate.content.parts or []
+            parts = []  # pragma: no cover
+        else:
+            parts = candidate.content.parts or []

         usage = _metadata_as_usage(response)
         return _process_response_from_parts(
@@ -506,7 +512,9 @@ class GoogleModel(Model):
             _provider_name=self._provider.name,
         )

-    async def _map_messages(self, messages: list[ModelMessage]) -> tuple[ContentDict | None, list[ContentUnionDict]]:
+    async def _map_messages(
+        self, messages: list[ModelMessage], model_request_parameters: ModelRequestParameters
+    ) -> tuple[ContentDict | None, list[ContentUnionDict]]:
         contents: list[ContentUnionDict] = []
         system_parts: list[PartDict] = []

@@ -553,7 +561,7 @@ class GoogleModel(Model):
             contents.append(_content_model_response(m, self.system))
         else:
             assert_never(m)
-        if instructions := self._get_instructions(messages):
+        if instructions := self._get_instructions(messages, model_request_parameters):
             system_parts.insert(0, {'text': instructions})
         system_instruction = ContentDict(role='user', parts=system_parts) if system_parts else None
         return system_instruction, contents
@@ -649,17 +657,12 @@ class GeminiStreamedResponse(StreamedResponse):
             # )

             if candidate.content is None or candidate.content.parts is None:
-                if self.finish_reason == 'stop':
-                    # Normal completion - skip this chunk
-                    continue
-                elif self.finish_reason == 'content_filter' and raw_finish_reason:  # pragma: no cover
+                if self.finish_reason == 'content_filter' and raw_finish_reason:  # pragma: no cover
                     raise UnexpectedModelBehavior(
                         f'Content filter {raw_finish_reason.value!r} triggered', chunk.model_dump_json()
                     )
                 else:  # pragma: no cover
-                    raise UnexpectedModelBehavior(
-                        'Content field missing from streaming Gemini response', chunk.model_dump_json()
-                    )
+                    continue

             parts = candidate.content.parts
             if not parts:
pydantic_ai/models/groq.py
CHANGED
@@ -272,7 +272,7 @@ class GroqModel(Model):
         else:
             tool_choice = 'auto'

-        groq_messages = self._map_messages(messages)
+        groq_messages = self._map_messages(messages, model_request_parameters)

         response_format: chat.completion_create_params.ResponseFormat | None = None
         if model_request_parameters.output_mode == 'native':
@@ -388,7 +388,9 @@ class GroqModel(Model):
         )
         return tools

-    def _map_messages(self, messages: list[ModelMessage]) -> list[chat.ChatCompletionMessageParam]:
+    def _map_messages(
+        self, messages: list[ModelMessage], model_request_parameters: ModelRequestParameters
+    ) -> list[chat.ChatCompletionMessageParam]:
         """Just maps a `pydantic_ai.Message` to a `groq.types.ChatCompletionMessageParam`."""
         groq_messages: list[chat.ChatCompletionMessageParam] = []
         for message in messages:
@@ -423,7 +425,7 @@ class GroqModel(Model):
             groq_messages.append(message_param)
         else:
             assert_never(message)
-        if instructions := self._get_instructions(messages):
+        if instructions := self._get_instructions(messages, model_request_parameters):
             groq_messages.insert(0, chat.ChatCompletionSystemMessageParam(role='system', content=instructions))
         return groq_messages

pydantic_ai/models/huggingface.py
CHANGED
@@ -231,7 +231,7 @@ class HuggingFaceModel(Model):
         if model_request_parameters.builtin_tools:
             raise UserError('HuggingFace does not support built-in tools')

-        hf_messages = await self._map_messages(messages)
+        hf_messages = await self._map_messages(messages, model_request_parameters)

         try:
             return await self.client.chat.completions.create(  # type: ignore
@@ -322,7 +322,7 @@ class HuggingFaceModel(Model):
         return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()]

     async def _map_messages(
-        self, messages: list[ModelMessage]
+        self, messages: list[ModelMessage], model_request_parameters: ModelRequestParameters
     ) -> list[ChatCompletionInputMessage | ChatCompletionOutputMessage]:
         """Just maps a `pydantic_ai.Message` to a `huggingface_hub.ChatCompletionInputMessage`."""
         hf_messages: list[ChatCompletionInputMessage | ChatCompletionOutputMessage] = []
@@ -359,7 +359,7 @@ class HuggingFaceModel(Model):
             hf_messages.append(message_param)
         else:
             assert_never(message)
-        if instructions := self._get_instructions(messages):
+        if instructions := self._get_instructions(messages, model_request_parameters):
             hf_messages.insert(0, ChatCompletionInputMessage(content=instructions, role='system'))  # type: ignore
         return hf_messages

pydantic_ai/models/instrumented.py
CHANGED
@@ -178,17 +178,20 @@ class InstrumentationSettings:
             description='Monetary cost',
         )

-    def messages_to_otel_events(self, messages: list[ModelMessage]) -> list[Event]:
+    def messages_to_otel_events(
+        self, messages: list[ModelMessage], parameters: ModelRequestParameters | None = None
+    ) -> list[Event]:
         """Convert a list of model messages to OpenTelemetry events.

         Args:
             messages: The messages to convert.
+            parameters: The model request parameters.

         Returns:
             A list of OpenTelemetry events.
         """
         events: list[Event] = []
-        instructions = InstrumentedModel._get_instructions(messages)  # pyright: ignore [reportPrivateUsage]
+        instructions = InstrumentedModel._get_instructions(messages, parameters)  # pyright: ignore [reportPrivateUsage]
         if instructions is not None:
             events.append(
                 Event(
@@ -235,10 +238,17 @@ class InstrumentationSettings:
             result.append(otel_message)
         return result

-    def handle_messages(self, input_messages: list[ModelMessage], response: ModelResponse, system: str, span: Span):
+    def handle_messages(
+        self,
+        input_messages: list[ModelMessage],
+        response: ModelResponse,
+        system: str,
+        span: Span,
+        parameters: ModelRequestParameters | None = None,
+    ):
         if self.version == 1:
-            events = self.messages_to_otel_events(input_messages)
-            for event in self.messages_to_otel_events([response]):
+            events = self.messages_to_otel_events(input_messages, parameters)
+            for event in self.messages_to_otel_events([response], parameters):
                 events.append(
                     Event(
                         'gen_ai.choice',
@@ -258,7 +268,7 @@ class InstrumentationSettings:
             output_messages = self.messages_to_otel_messages([response])
             assert len(output_messages) == 1
             output_message = output_messages[0]
-            instructions = InstrumentedModel._get_instructions(input_messages)  # pyright: ignore [reportPrivateUsage]
+            instructions = InstrumentedModel._get_instructions(input_messages, parameters)  # pyright: ignore [reportPrivateUsage]
             system_instructions_attributes = self.system_instructions_attributes(instructions)
             attributes: dict[str, AttributeValue] = {
                 'gen_ai.input.messages': json.dumps(self.messages_to_otel_messages(input_messages)),
@@ -360,7 +370,7 @@ class InstrumentedModel(WrapperModel):
         )
         with self._instrument(messages, prepared_settings, prepared_parameters) as finish:
             response = await self.wrapped.request(messages, model_settings, model_request_parameters)
-            finish(response)
+            finish(response, prepared_parameters)
             return response

     @asynccontextmanager
@@ -384,7 +394,7 @@ class InstrumentedModel(WrapperModel):
                 yield response_stream
             finally:
                 if response_stream:  # pragma: no branch
-                    finish(response_stream.get())
+                    finish(response_stream.get(), prepared_parameters)

     @contextmanager
     def _instrument(
@@ -392,7 +402,7 @@ class InstrumentedModel(WrapperModel):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
-    ) -> Iterator[Callable[[ModelResponse], None]]:
+    ) -> Iterator[Callable[[ModelResponse, ModelRequestParameters], None]]:
         operation = 'chat'
         span_name = f'{operation} {self.model_name}'
         # TODO Missing attributes:
@@ -401,7 +411,7 @@ class InstrumentedModel(WrapperModel):
         attributes: dict[str, AttributeValue] = {
             'gen_ai.operation.name': operation,
             **self.model_attributes(self.wrapped),
-            'model_request_parameters': json.dumps(self.serialize_any(model_request_parameters)),
+            **self.model_request_parameters_attributes(model_request_parameters),
             'logfire.json_schema': json.dumps(
                 {
                     'type': 'object',
@@ -419,7 +429,7 @@ class InstrumentedModel(WrapperModel):
         try:
             with self.instrumentation_settings.tracer.start_as_current_span(span_name, attributes=attributes) as span:

-                def finish(response: ModelResponse):
+                def finish(response: ModelResponse, parameters: ModelRequestParameters):
                     # FallbackModel updates these span attributes.
                     attributes.update(getattr(span, 'attributes', {}))
                     request_model = attributes[GEN_AI_REQUEST_MODEL_ATTRIBUTE]
@@ -443,7 +453,7 @@ class InstrumentedModel(WrapperModel):
                     if not span.is_recording():
                         return

-                    self.instrumentation_settings.handle_messages(messages, response, system, span)
+                    self.instrumentation_settings.handle_messages(messages, response, system, span, parameters)

                     attributes_to_set = {
                         **response.usage.opentelemetry_attributes(),
@@ -476,7 +486,7 @@ class InstrumentedModel(WrapperModel):
             record_metrics()

     @staticmethod
-    def model_attributes(model: Model):
+    def model_attributes(model: Model) -> dict[str, AttributeValue]:
         attributes: dict[str, AttributeValue] = {
             GEN_AI_SYSTEM_ATTRIBUTE: model.system,
             GEN_AI_REQUEST_MODEL_ATTRIBUTE: model.model_name,
@@ -494,6 +504,12 @@ class InstrumentedModel(WrapperModel):

         return attributes

+    @staticmethod
+    def model_request_parameters_attributes(
+        model_request_parameters: ModelRequestParameters,
+    ) -> dict[str, AttributeValue]:
+        return {'model_request_parameters': json.dumps(InstrumentedModel.serialize_any(model_request_parameters))}
+
     @staticmethod
     def event_to_dict(event: Event) -> dict[str, Any]:
         if not event.body:
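The instrumentation changes thread the prepared `ModelRequestParameters` through `finish()` into `handle_messages` and `messages_to_otel_events`, and factor the span attribute into a reusable static helper that `FallbackModel` also calls. A minimal sketch of the new helper (assuming `ModelRequestParameters` is constructed with its defaults):

from pydantic_ai.models import ModelRequestParameters
from pydantic_ai.models.instrumented import InstrumentedModel

# model_request_parameters_attributes() serializes the prepared parameters into
# the single 'model_request_parameters' span attribute as a JSON string.
attrs = InstrumentedModel.model_request_parameters_attributes(ModelRequestParameters())
print(attrs['model_request_parameters'])  # JSON string of the prepared parameters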
pydantic_ai/models/mistral.py
CHANGED
@@ -230,7 +230,7 @@ class MistralModel(Model):
         try:
             response = await self.client.chat.complete_async(
                 model=str(self._model_name),
-                messages=self._map_messages(messages),
+                messages=self._map_messages(messages, model_request_parameters),
                 n=1,
                 tools=self._map_function_and_output_tools_definition(model_request_parameters) or UNSET,
                 tool_choice=self._get_tool_choice(model_request_parameters),
@@ -259,7 +259,7 @@ class MistralModel(Model):
     ) -> MistralEventStreamAsync[MistralCompletionEvent]:
         """Create a streaming completion request to the Mistral model."""
         response: MistralEventStreamAsync[MistralCompletionEvent] | None
-        mistral_messages = self._map_messages(messages)
+        mistral_messages = self._map_messages(messages, model_request_parameters)

         # TODO(Marcelo): We need to replace the current MistralAI client to use the beta client.
         # See https://docs.mistral.ai/agents/connectors/websearch/ to support web search.
@@ -523,7 +523,9 @@ class MistralModel(Model):
         else:
             assert_never(part)

-    def _map_messages(self, messages: list[ModelMessage]) -> list[MistralMessages]:
+    def _map_messages(
+        self, messages: list[ModelMessage], model_request_parameters: ModelRequestParameters
+    ) -> list[MistralMessages]:
         """Just maps a `pydantic_ai.Message` to a `MistralMessage`."""
         mistral_messages: list[MistralMessages] = []
         for message in messages:
@@ -554,7 +556,7 @@ class MistralModel(Model):
             mistral_messages.append(MistralAssistantMessage(content=content_chunks, tool_calls=tool_calls))
         else:
             assert_never(message)
-        if instructions := self._get_instructions(messages):
+        if instructions := self._get_instructions(messages, model_request_parameters):
             mistral_messages.insert(0, MistralSystemMessage(content=instructions))

         # Post-process messages to insert fake assistant message after tool message if followed by user message