pydantic-ai-slim 0.6.2__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
Potentially problematic release: this version of pydantic-ai-slim might be problematic.
- pydantic_ai/_a2a.py +6 -4
- pydantic_ai/_agent_graph.py +25 -32
- pydantic_ai/_cli.py +3 -3
- pydantic_ai/_output.py +8 -0
- pydantic_ai/_tool_manager.py +3 -0
- pydantic_ai/ag_ui.py +25 -14
- pydantic_ai/{agent.py → agent/__init__.py} +209 -1027
- pydantic_ai/agent/abstract.py +942 -0
- pydantic_ai/agent/wrapper.py +227 -0
- pydantic_ai/direct.py +9 -9
- pydantic_ai/durable_exec/__init__.py +0 -0
- pydantic_ai/durable_exec/temporal/__init__.py +83 -0
- pydantic_ai/durable_exec/temporal/_agent.py +699 -0
- pydantic_ai/durable_exec/temporal/_function_toolset.py +92 -0
- pydantic_ai/durable_exec/temporal/_logfire.py +48 -0
- pydantic_ai/durable_exec/temporal/_mcp_server.py +145 -0
- pydantic_ai/durable_exec/temporal/_model.py +168 -0
- pydantic_ai/durable_exec/temporal/_run_context.py +50 -0
- pydantic_ai/durable_exec/temporal/_toolset.py +77 -0
- pydantic_ai/ext/aci.py +10 -9
- pydantic_ai/ext/langchain.py +4 -2
- pydantic_ai/mcp.py +203 -75
- pydantic_ai/messages.py +2 -2
- pydantic_ai/models/__init__.py +65 -9
- pydantic_ai/models/anthropic.py +16 -7
- pydantic_ai/models/bedrock.py +8 -5
- pydantic_ai/models/cohere.py +1 -4
- pydantic_ai/models/fallback.py +4 -2
- pydantic_ai/models/function.py +9 -4
- pydantic_ai/models/gemini.py +15 -9
- pydantic_ai/models/google.py +18 -14
- pydantic_ai/models/groq.py +17 -14
- pydantic_ai/models/huggingface.py +18 -12
- pydantic_ai/models/instrumented.py +3 -1
- pydantic_ai/models/mcp_sampling.py +3 -1
- pydantic_ai/models/mistral.py +12 -18
- pydantic_ai/models/openai.py +29 -26
- pydantic_ai/models/test.py +3 -0
- pydantic_ai/models/wrapper.py +6 -2
- pydantic_ai/profiles/openai.py +1 -1
- pydantic_ai/providers/google.py +7 -7
- pydantic_ai/result.py +21 -55
- pydantic_ai/run.py +357 -0
- pydantic_ai/tools.py +0 -1
- pydantic_ai/toolsets/__init__.py +2 -0
- pydantic_ai/toolsets/_dynamic.py +87 -0
- pydantic_ai/toolsets/abstract.py +23 -3
- pydantic_ai/toolsets/combined.py +19 -4
- pydantic_ai/toolsets/deferred.py +10 -2
- pydantic_ai/toolsets/function.py +23 -8
- pydantic_ai/toolsets/prefixed.py +4 -0
- pydantic_ai/toolsets/wrapper.py +14 -1
- {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.0.dist-info}/METADATA +6 -4
- {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.0.dist-info}/RECORD +57 -44
- {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.0.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.0.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.0.dist-info}/licenses/LICENSE +0 -0
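
Two structural changes dominate the file list: `pydantic_ai/agent.py` becomes the `agent/` package (with `abstract.py` and `wrapper.py` split out), so existing `from pydantic_ai.agent import ...` paths resolve unchanged, and a new `pydantic_ai/durable_exec/temporal/` package adds Temporal-backed durable execution. A minimal sketch of the new surface; the `TemporalAgent` name is an assumption inferred from `durable_exec/temporal/_agent.py` and should be verified against the 0.7.0 docs:

```python
# Unchanged: agent.py became agent/__init__.py, so this import resolves as before.
from pydantic_ai.agent import Agent

# New in 0.7.0 (class name assumed from durable_exec/temporal/_agent.py, not confirmed here):
from pydantic_ai.durable_exec.temporal import TemporalAgent

agent = Agent('openai:gpt-4o', name='support_agent')
# Wraps the agent's model and toolset calls so they can run as Temporal activities.
temporal_agent = TemporalAgent(agent)
```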
pydantic_ai/models/bedrock.py
CHANGED
@@ -15,6 +15,7 @@ import anyio.to_thread
 from typing_extensions import ParamSpec, assert_never

 from pydantic_ai import _utils, usage
+from pydantic_ai._run_context import RunContext
 from pydantic_ai.exceptions import UserError
 from pydantic_ai.messages import (
     AudioUrl,
@@ -230,10 +231,7 @@ class BedrockConverseModel(Model):
         super().__init__(settings=settings, profile=profile or provider.model_profile)

     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolTypeDef]:
-        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
-        if model_request_parameters.output_tools:
-            tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
-        return tools
+        return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()]

     @staticmethod
     def _map_tool_definition(f: ToolDefinition) -> ToolTypeDef:
@@ -269,10 +267,15 @@ class BedrockConverseModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         settings = cast(BedrockModelSettings, model_settings or {})
         response = await self._messages_create(messages, True, settings, model_request_parameters)
-        yield BedrockStreamedResponse(_model_name=self.model_name, _event_stream=response)
+        yield BedrockStreamedResponse(
+            model_request_parameters=model_request_parameters,
+            _model_name=self.model_name,
+            _event_stream=response,
+        )

     async def _process_response(self, response: ConverseResponseTypeDef) -> ModelResponse:
         items: list[ModelResponsePart] = []
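
The `_get_tools` consolidation above recurs in the Cohere, Gemini, Google, Groq, Hugging Face, and Mistral hunks below: the separate `function_tools` and `output_tools` lists are replaced by a single `tool_defs` mapping on `ModelRequestParameters`. A minimal before/after sketch for a custom `Model` subclass; `map_tool` is a hypothetical provider-specific mapper, and the assumption that `tool_defs` maps tool names to `ToolDefinition`s is inferred from the `.values()` calls in this diff:

```python
from pydantic_ai.models import ModelRequestParameters
from pydantic_ai.tools import ToolDefinition


def map_tool(t: ToolDefinition) -> dict:
    # Hypothetical provider-specific mapping, for illustration only.
    return {'name': t.name, 'description': t.description, 'parameters': t.parameters_json_schema}


def get_tools_old(params: ModelRequestParameters) -> list[dict]:
    # 0.6.x: every model class merged the two lists itself.
    tools = [map_tool(t) for t in params.function_tools]
    if params.output_tools:
        tools += [map_tool(t) for t in params.output_tools]
    return tools


def get_tools_new(params: ModelRequestParameters) -> list[dict]:
    # 0.7.0: one mapping covers function and output tools alike.
    return [map_tool(t) for t in params.tool_defs.values()]
```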
pydantic_ai/models/cohere.py
CHANGED
@@ -248,10 +248,7 @@ class CohereModel(Model):
         return cohere_messages

     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolV2]:
-        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
-        if model_request_parameters.output_tools:
-            tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
-        return tools
+        return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()]

     @staticmethod
     def _map_tool_call(t: ToolCallPart) -> ToolCallV2:
pydantic_ai/models/fallback.py
CHANGED
@@ -3,10 +3,11 @@ from __future__ import annotations as _annotations
 from collections.abc import AsyncIterator
 from contextlib import AsyncExitStack, asynccontextmanager, suppress
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Callable
+from typing import TYPE_CHECKING, Any, Callable

 from opentelemetry.trace import get_current_span

+from pydantic_ai._run_context import RunContext
 from pydantic_ai.models.instrumented import InstrumentedModel

 from ..exceptions import FallbackExceptionGroup, ModelHTTPError
@@ -83,6 +84,7 @@ class FallbackModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         """Try each model in sequence until one succeeds."""
         exceptions: list[Exception] = []
@@ -92,7 +94,7 @@ class FallbackModel(Model):
         async with AsyncExitStack() as stack:
             try:
                 response = await stack.enter_async_context(
-                    model.request_stream(messages, model_settings, customized_model_request_parameters)
+                    model.request_stream(messages, model_settings, customized_model_request_parameters, run_context)
                 )
             except Exception as exc:
                 if self._fallback_on(exc):
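
Across this release, every `request_stream` implementation (Bedrock, Fallback, Function, Gemini, Google, Groq, Hugging Face, Instrumented, MCP sampling, Mistral) gains the same trailing `run_context` parameter with a `None` default, so existing callers and third-party `Model` subclasses keep working unchanged. A sketch of the updated override, with the signature taken from the hunks and the body elided:

```python
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Any

from pydantic_ai import RunContext  # public re-export of _run_context.RunContext
from pydantic_ai.messages import ModelMessage
from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse
from pydantic_ai.settings import ModelSettings


class MyModel(Model):
    @asynccontextmanager
    async def request_stream(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
        run_context: RunContext[Any] | None = None,  # new in 0.7.0; None preserves old behaviour
    ) -> AsyncIterator[StreamedResponse]:
        # Provider-specific streaming logic goes here; see any model above for the shape.
        raise NotImplementedError
        yield
```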
pydantic_ai/models/function.py
CHANGED
@@ -7,13 +7,12 @@ from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime
 from itertools import chain
-from typing import Callable, Union
+from typing import Any, Callable, Union

 from typing_extensions import TypeAlias, assert_never, overload

-from pydantic_ai.profiles import ModelProfileSpec
-
 from .. import _utils, usage
+from .._run_context import RunContext
 from .._utils import PeekableAsyncStream
 from ..messages import (
     BinaryContent,
@@ -32,6 +31,7 @@ from ..messages import (
     UserContent,
     UserPromptPart,
 )
+from ..profiles import ModelProfileSpec
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import Model, ModelRequestParameters, StreamedResponse
@@ -147,6 +147,7 @@ class FunctionModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         agent_info = AgentInfo(
             model_request_parameters.function_tools,
@@ -165,7 +166,11 @@ class FunctionModel(Model):
         if isinstance(first, _utils.Unset):
             raise ValueError('Stream function must return at least one item')

-        yield FunctionStreamedResponse(_model_name=self._model_name, _iter=response_stream)
+        yield FunctionStreamedResponse(
+            model_request_parameters=model_request_parameters,
+            _model_name=self._model_name,
+            _iter=response_stream,
+        )

     @property
     def model_name(self) -> str:
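
`FunctionModel`'s public API is untouched; only the internals change (the streamed response now carries `model_request_parameters`, and `request_stream` accepts the optional `run_context`). For orientation, a minimal streaming test double in the documented `FunctionModel` style; treat the exact result-access method as an assumption if you are on a different minor version:

```python
from collections.abc import AsyncIterator

from pydantic_ai import Agent
from pydantic_ai.messages import ModelMessage
from pydantic_ai.models.function import AgentInfo, FunctionModel


async def stream_func(messages: list[ModelMessage], info: AgentInfo) -> AsyncIterator[str]:
    # Each yielded string becomes a text delta in the streamed response.
    yield 'hello '
    yield 'world'


agent = Agent(FunctionModel(stream_function=stream_func))


async def main() -> None:
    async with agent.run_stream('hi') as result:
        print(await result.get_output())
        #> hello world
```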
pydantic_ai/models/gemini.py
CHANGED
@@ -13,10 +13,9 @@ import pydantic
 from httpx import USE_CLIENT_DEFAULT, Response as HTTPResponse
 from typing_extensions import NotRequired, TypedDict, assert_never, deprecated

-from pydantic_ai.providers import Provider, infer_provider
-
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
 from .._output import OutputObjectDefinition
+from .._run_context import RunContext
 from ..exceptions import UserError
 from ..messages import (
     BinaryContent,
@@ -38,6 +37,7 @@ from ..messages import (
     VideoUrl,
 )
 from ..profiles import ModelProfileSpec
+from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
@@ -167,12 +167,13 @@ class GeminiModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         check_allow_model_requests()
         async with self._make_request(
             messages, True, cast(GeminiModelSettings, model_settings or {}), model_request_parameters
         ) as http_response:
-            yield await self._process_streamed_response(http_response)
+            yield await self._process_streamed_response(http_response, model_request_parameters)

     @property
     def model_name(self) -> GeminiModelName:
@@ -185,9 +186,7 @@ class GeminiModel(Model):
         return self._system

     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> _GeminiTools | None:
-        tools = [_function_from_abstract_tool(t) for t in model_request_parameters.function_tools]
-        if model_request_parameters.output_tools:
-            tools += [_function_from_abstract_tool(t) for t in model_request_parameters.output_tools]
+        tools = [_function_from_abstract_tool(t) for t in model_request_parameters.tool_defs.values()]
         return _GeminiTools(function_declarations=tools) if tools else None

     def _get_tool_config(
@@ -288,7 +287,9 @@ class GeminiModel(Model):
             vendor_details=vendor_details,
         )

-    async def _process_streamed_response(self, http_response: HTTPResponse) -> StreamedResponse:
+    async def _process_streamed_response(
+        self, http_response: HTTPResponse, model_request_parameters: ModelRequestParameters
+    ) -> StreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         aiter_bytes = http_response.aiter_bytes()
         start_response: _GeminiResponse | None = None
@@ -309,7 +310,12 @@ class GeminiModel(Model):
         if start_response is None:
             raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')

-        return GeminiStreamedResponse(_model_name=self._model_name, _content=content, _stream=aiter_bytes)
+        return GeminiStreamedResponse(
+            model_request_parameters=model_request_parameters,
+            _model_name=self._model_name,
+            _content=content,
+            _stream=aiter_bytes,
+        )

     async def _message_to_gemini_content(
         self, messages: list[ModelMessage]
@@ -872,7 +878,7 @@ def _metadata_as_usage(response: _GeminiResponse) -> usage.Usage:
         metadata_details = cast(list[_GeminiModalityTokenCount], metadata_details)
         suffix = key.removesuffix('_details')
         for detail in metadata_details:
-            details[f'{detail["modality"].lower()}_{suffix}'] = detail['token_count']
+            details[f'{detail["modality"].lower()}_{suffix}'] = detail.get('token_count', 0)

     return usage.Usage(
         request_tokens=metadata.get('prompt_token_count', 0),
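
Both Gemini code paths harden `_metadata_as_usage` the same way: modality detail entries apparently may omit `token_count`, so direct indexing is replaced by `dict.get` with a zero default. A worked illustration of the failure the change guards against (hypothetical payload):

```python
# A modality entry as returned in usage metadata; 'token_count' can be absent.
detail = {'modality': 'AUDIO'}

# detail['token_count'] would raise KeyError here; the 0.7.0 code falls back to 0.
value = detail.get('token_count', 0)
assert value == 0
```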
pydantic_ai/models/google.py
CHANGED
@@ -8,11 +8,11 @@ from datetime import datetime
 from typing import Any, Literal, Union, cast, overload
 from uuid import uuid4

-from google.genai.types import ExecutableCodeDict
 from typing_extensions import assert_never

 from .. import UnexpectedModelBehavior, _utils, usage
 from .._output import OutputObjectDefinition
+from .._run_context import RunContext
 from ..builtin_tools import CodeExecutionTool, WebSearchTool
 from ..exceptions import UserError
 from ..messages import (
@@ -48,10 +48,11 @@ from . import (
 )

 try:
-    from google import genai
+    from google.genai import Client
     from google.genai.types import (
         ContentDict,
         ContentUnionDict,
+        ExecutableCodeDict,
         FunctionCallDict,
         FunctionCallingConfigDict,
         FunctionCallingConfigMode,
@@ -136,10 +137,10 @@ class GoogleModel(Model):
     Apart from `__init__`, all methods are private or match those of the base class.
     """

-    client: genai.Client = field(repr=False)
+    client: Client = field(repr=False)

     _model_name: GoogleModelName = field(repr=False)
-    _provider: Provider[genai.Client] = field(repr=False)
+    _provider: Provider[Client] = field(repr=False)
     _url: str | None = field(repr=False)
     _system: str = field(default='google', repr=False)

@@ -147,7 +148,7 @@ class GoogleModel(Model):
         self,
         model_name: GoogleModelName,
         *,
-        provider: Literal['google-gla', 'google-vertex'] | Provider[genai.Client] = 'google-gla',
+        provider: Literal['google-gla', 'google-vertex'] | Provider[Client] = 'google-gla',
         profile: ModelProfileSpec | None = None,
         settings: ModelSettings | None = None,
     ):
@@ -193,11 +194,12 @@ class GoogleModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         check_allow_model_requests()
         model_settings = cast(GoogleModelSettings, model_settings or {})
         response = await self._generate_content(messages, True, model_settings, model_request_parameters)
-        yield await self._process_streamed_response(response)  # type: ignore
+        yield await self._process_streamed_response(response, model_request_parameters)  # type: ignore

     @property
     def model_name(self) -> GoogleModelName:
@@ -212,18 +214,17 @@ class GoogleModel(Model):
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolDict] | None:
         tools: list[ToolDict] = [
             ToolDict(function_declarations=[_function_declaration_from_tool(t)])
-            for t in model_request_parameters.function_tools
+            for t in model_request_parameters.tool_defs.values()
         ]
-        if model_request_parameters.output_tools:
-            tools += [
-                ToolDict(function_declarations=[_function_declaration_from_tool(t)])
-                for t in model_request_parameters.output_tools
-            ]
         for tool in model_request_parameters.builtin_tools:
             if isinstance(tool, WebSearchTool):
                 tools.append(ToolDict(google_search=GoogleSearchDict()))
             elif isinstance(tool, CodeExecutionTool):  # pragma: no branch
                 tools.append(ToolDict(code_execution=ToolCodeExecutionDict()))
+            else:  # pragma: no cover
+                raise UserError(
+                    f'`{tool.__class__.__name__}` is not supported by `GoogleModel`. If it should be, please file an issue.'
+                )
         return tools or None

     def _get_tool_config(
@@ -336,7 +337,9 @@ class GoogleModel(Model):
             parts, response.model_version or self._model_name, usage, vendor_id=vendor_id, vendor_details=vendor_details
         )

-    async def _process_streamed_response(self, response: AsyncIterator[GenerateContentResponse]) -> StreamedResponse:
+    async def _process_streamed_response(
+        self, response: AsyncIterator[GenerateContentResponse], model_request_parameters: ModelRequestParameters
+    ) -> StreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         peekable_response = _utils.PeekableAsyncStream(response)
         first_chunk = await peekable_response.peek()
@@ -344,6 +347,7 @@ class GoogleModel(Model):
             raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')  # pragma: no cover

         return GeminiStreamedResponse(
+            model_request_parameters=model_request_parameters,
             _model_name=self._model_name,
             _response=peekable_response,
             _timestamp=first_chunk.create_time or _utils.now_utc(),
@@ -603,7 +607,7 @@ def _metadata_as_usage(response: GenerateContentResponse) -> usage.Usage:
         if key.endswith('_details') and metadata_details:
             suffix = key.removesuffix('_details')
             for detail in metadata_details:
-                details[f'{detail["modality"].lower()}_{suffix}'] = detail['token_count']
+                details[f'{detail["modality"].lower()}_{suffix}'] = detail.get('token_count', 0)

     return usage.Usage(
         request_tokens=metadata.get('prompt_token_count', 0),
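
`GoogleModel` now imports the client inside the optional-dependency `try` block (`from google.genai import Client`) and annotates `client` and `provider` with `Client` instead of `genai.Client`; construction is otherwise unchanged. Per the updated signature, usage still looks like this (the model name is chosen for illustration):

```python
from pydantic_ai.models.google import GoogleModel

# 'google-gla' stays the default; a Provider[Client] instance is also accepted.
model = GoogleModel('gemini-1.5-flash', provider='google-gla')
```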
pydantic_ai/models/groq.py
CHANGED
@@ -5,17 +5,16 @@ from collections.abc import AsyncIterable, AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Literal, Union, cast, overload
+from typing import Any, Literal, Union, cast, overload

 from typing_extensions import assert_never

-from pydantic_ai._thinking_part import split_content_into_text_and_thinking
-from pydantic_ai.exceptions import UserError
-from pydantic_ai.profiles.groq import GroqModelProfile
-
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
+from .._run_context import RunContext
+from .._thinking_part import split_content_into_text_and_thinking
 from .._utils import generate_tool_call_id, guard_tool_call_id as _guard_tool_call_id, number_to_datetime
-from ..builtin_tools import
+from ..builtin_tools import WebSearchTool
+from ..exceptions import UserError
 from ..messages import (
     BinaryContent,
     BuiltinToolCallPart,
@@ -36,6 +35,7 @@ from ..messages import (
     UserPromptPart,
 )
 from ..profiles import ModelProfile, ModelProfileSpec
+from ..profiles.groq import GroqModelProfile
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
@@ -171,13 +171,14 @@ class GroqModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         check_allow_model_requests()
         response = await self._completions_create(
             messages, True, cast(GroqModelSettings, model_settings or {}), model_request_parameters
         )
         async with response:
-            yield await self._process_streamed_response(response)
+            yield await self._process_streamed_response(response, model_request_parameters)

     @property
     def model_name(self) -> GroqModelName:
@@ -287,7 +288,9 @@ class GroqModel(Model):
             items, usage=_map_usage(response), model_name=response.model, timestamp=timestamp, vendor_id=response.id
         )

-    async def _process_streamed_response(self, response: AsyncStream[chat.ChatCompletionChunk]) -> GroqStreamedResponse:
+    async def _process_streamed_response(
+        self, response: AsyncStream[chat.ChatCompletionChunk], model_request_parameters: ModelRequestParameters
+    ) -> GroqStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         peekable_response = _utils.PeekableAsyncStream(response)
         first_chunk = await peekable_response.peek()
@@ -297,6 +300,7 @@ class GroqModel(Model):
         )

         return GroqStreamedResponse(
+            model_request_parameters=model_request_parameters,
             _response=peekable_response,
             _model_name=self._model_name,
             _model_profile=self.profile,
@@ -304,10 +308,7 @@ class GroqModel(Model):
         )

     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:
-        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
-        if model_request_parameters.output_tools:
-            tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
-        return tools
+        return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()]

     def _get_builtin_tools(
         self, model_request_parameters: ModelRequestParameters
@@ -317,8 +318,10 @@ class GroqModel(Model):
             if isinstance(tool, WebSearchTool):
                 if not GroqModelProfile.from_profile(self.profile).groq_always_has_web_search_builtin_tool:
                     raise UserError('`WebSearchTool` is not supported by Groq')  # pragma: no cover
-
-            raise UserError(
+            else:
+                raise UserError(
+                    f'`{tool.__class__.__name__}` is not supported by `GroqModel`. If it should be, please file an issue.'
+                )
         return tools

     def _map_messages(self, messages: list[ModelMessage]) -> list[chat.ChatCompletionMessageParam]:
pydantic_ai/models/huggingface.py
CHANGED
@@ -5,16 +5,15 @@ from collections.abc import AsyncIterable, AsyncIterator
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
-from typing import Literal, Union, cast, overload
+from typing import Any, Literal, Union, cast, overload

 from typing_extensions import assert_never

-from pydantic_ai._thinking_part import split_content_into_text_and_thinking
-from pydantic_ai.exceptions import UserError
-from pydantic_ai.providers import Provider, infer_provider
-
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
+from .._run_context import RunContext
+from .._thinking_part import split_content_into_text_and_thinking
 from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc
+from ..exceptions import UserError
 from ..messages import (
     AudioUrl,
     BinaryContent,
@@ -37,9 +36,15 @@ from ..messages import (
     VideoUrl,
 )
 from ..profiles import ModelProfile
+from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests
+from . import (
+    Model,
+    ModelRequestParameters,
+    StreamedResponse,
+    check_allow_model_requests,
+)

 try:
     import aiohttp
@@ -150,12 +155,13 @@ class HuggingFaceModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         check_allow_model_requests()
         response = await self._completions_create(
             messages, True, cast(HuggingFaceModelSettings, model_settings or {}), model_request_parameters
         )
-        yield await self._process_streamed_response(response)
+        yield await self._process_streamed_response(response, model_request_parameters)

     @property
     def model_name(self) -> HuggingFaceModelName:
@@ -263,7 +269,9 @@ class HuggingFaceModel(Model):
             vendor_id=response.id,
         )

-    async def _process_streamed_response(self, response: AsyncIterable[ChatCompletionStreamOutput]) -> StreamedResponse:
+    async def _process_streamed_response(
+        self, response: AsyncIterable[ChatCompletionStreamOutput], model_request_parameters: ModelRequestParameters
+    ) -> StreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         peekable_response = _utils.PeekableAsyncStream(response)
         first_chunk = await peekable_response.peek()
@@ -273,6 +281,7 @@ class HuggingFaceModel(Model):
         )

         return HuggingFaceStreamedResponse(
+            model_request_parameters=model_request_parameters,
             _model_name=self._model_name,
             _model_profile=self.profile,
             _response=peekable_response,
@@ -280,10 +289,7 @@ class HuggingFaceModel(Model):
         )

     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ChatCompletionInputTool]:
-        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
-        if model_request_parameters.output_tools:
-            tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
-        return tools
+        return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()]

     async def _map_messages(
         self, messages: list[ModelMessage]
pydantic_ai/models/instrumented.py
CHANGED
@@ -18,6 +18,7 @@ from opentelemetry.trace import Span, Tracer, TracerProvider, get_tracer_provider
 from opentelemetry.util.types import AttributeValue
 from pydantic import TypeAdapter

+from .._run_context import RunContext
 from ..messages import ModelMessage, ModelRequest, ModelResponse
 from ..settings import ModelSettings
 from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse
@@ -218,12 +219,13 @@ class InstrumentedModel(WrapperModel):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         with self._instrument(messages, model_settings, model_request_parameters) as finish:
             response_stream: StreamedResponse | None = None
             try:
                 async with super().request_stream(
-                    messages, model_settings, model_request_parameters
+                    messages, model_settings, model_request_parameters, run_context
                 ) as response_stream:
                     yield response_stream
             finally:
pydantic_ai/models/mcp_sampling.py
CHANGED
@@ -3,9 +3,10 @@ from __future__ import annotations as _annotations
 from collections.abc import AsyncIterator
 from contextlib import asynccontextmanager
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, cast
+from typing import TYPE_CHECKING, Any, cast

 from .. import _mcp, exceptions, usage
+from .._run_context import RunContext
 from ..messages import ModelMessage, ModelResponse
 from ..settings import ModelSettings
 from . import Model, ModelRequestParameters, StreamedResponse
@@ -76,6 +77,7 @@ class MCPSamplingModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         raise NotImplementedError('MCP Sampling does not support streaming')
         yield
pydantic_ai/models/mistral.py
CHANGED
@@ -11,11 +11,11 @@ import pydantic_core
 from httpx import Timeout
 from typing_extensions import assert_never

-from pydantic_ai._thinking_part import split_content_into_text_and_thinking
-from pydantic_ai.exceptions import UserError
-
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils
+from .._run_context import RunContext
+from .._thinking_part import split_content_into_text_and_thinking
 from .._utils import generate_tool_call_id as _generate_tool_call_id, now_utc as _now_utc, number_to_datetime
+from ..exceptions import UserError
 from ..messages import (
     BinaryContent,
     BuiltinToolCallPart,
@@ -176,6 +176,7 @@ class MistralModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         """Make a streaming request to the model from Pydantic AI call."""
         check_allow_model_requests()
@@ -183,7 +184,7 @@ class MistralModel(Model):
             messages, cast(MistralModelSettings, model_settings or {}), model_request_parameters
         )
         async with response:
-            yield await self._process_streamed_response(model_request_parameters.output_tools, response)
+            yield await self._process_streamed_response(response, model_request_parameters)

     @property
     def model_name(self) -> MistralModelName:
@@ -246,11 +247,7 @@ class MistralModel(Model):
         if model_request_parameters.builtin_tools:
             raise UserError('Mistral does not support built-in tools')

-        if (
-            model_request_parameters.output_tools
-            and model_request_parameters.function_tools
-            or model_request_parameters.function_tools
-        ):
+        if model_request_parameters.function_tools:
             # Function Calling
             response = await self.client.chat.stream_async(
                 model=str(self._model_name),
@@ -318,16 +315,13 @@ class MistralModel(Model):

         Returns None if both function_tools and output_tools are empty.
         """
-        all_tools: list[ToolDefinition] = (
-            model_request_parameters.function_tools + model_request_parameters.output_tools
-        )
         tools = [
             MistralTool(
                 function=MistralFunction(
                     name=r.name, parameters=r.parameters_json_schema, description=r.description or ''
                 )
             )
-            for r in all_tools
+            for r in model_request_parameters.tool_defs.values()
         ]
         return tools if tools else None

@@ -359,8 +353,8 @@ class MistralModel(Model):

     async def _process_streamed_response(
         self,
-        output_tools: list[ToolDefinition],
         response: MistralEventStreamAsync[MistralCompletionEvent],
+        model_request_parameters: ModelRequestParameters,
     ) -> StreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         peekable_response = _utils.PeekableAsyncStream(response)
@@ -376,10 +370,10 @@ class MistralModel(Model):
         timestamp = _now_utc()

         return MistralStreamedResponse(
+            model_request_parameters=model_request_parameters,
             _response=peekable_response,
             _model_name=self._model_name,
             _timestamp=timestamp,
-            _output_tools={c.name: c for c in output_tools},
         )

     @staticmethod
@@ -586,7 +580,6 @@ class MistralStreamedResponse(StreamedResponse):
     _model_name: MistralModelName
     _response: AsyncIterable[MistralCompletionEvent]
     _timestamp: datetime
-    _output_tools: dict[str, ToolDefinition]

     _delta_content: str = field(default='', init=False)

@@ -605,10 +598,11 @@ class MistralStreamedResponse(StreamedResponse):
             text = _map_content(content)
             if text:
                 # Attempt to produce an output tool call from the received text
-                if self._output_tools:
+                output_tools = {c.name: c for c in self.model_request_parameters.output_tools}
+                if output_tools:
                     self._delta_content += text
                     # TODO: Port to native "manual JSON" mode
-                    maybe_tool_call_part = self._try_get_output_tool_from_text(self._delta_content, self._output_tools)
+                    maybe_tool_call_part = self._try_get_output_tool_from_text(self._delta_content, output_tools)
                     if maybe_tool_call_part:
                         yield self._parts_manager.handle_tool_call_part(
                             vendor_part_id='output',