pydantic-ai-slim 0.6.2__py3-none-any.whl → 0.7.1__py3-none-any.whl
This diff compares the contents of publicly released package versions as they appear in their respective public registries; it is provided for informational purposes only.
This version of pydantic-ai-slim has been flagged as potentially problematic.
- pydantic_ai/_a2a.py +6 -4
- pydantic_ai/_agent_graph.py +37 -37
- pydantic_ai/_cli.py +3 -3
- pydantic_ai/_output.py +8 -0
- pydantic_ai/_tool_manager.py +3 -0
- pydantic_ai/ag_ui.py +25 -14
- pydantic_ai/{agent.py → agent/__init__.py} +209 -1027
- pydantic_ai/agent/abstract.py +942 -0
- pydantic_ai/agent/wrapper.py +227 -0
- pydantic_ai/direct.py +9 -9
- pydantic_ai/durable_exec/__init__.py +0 -0
- pydantic_ai/durable_exec/temporal/__init__.py +83 -0
- pydantic_ai/durable_exec/temporal/_agent.py +699 -0
- pydantic_ai/durable_exec/temporal/_function_toolset.py +92 -0
- pydantic_ai/durable_exec/temporal/_logfire.py +48 -0
- pydantic_ai/durable_exec/temporal/_mcp_server.py +145 -0
- pydantic_ai/durable_exec/temporal/_model.py +168 -0
- pydantic_ai/durable_exec/temporal/_run_context.py +50 -0
- pydantic_ai/durable_exec/temporal/_toolset.py +77 -0
- pydantic_ai/ext/aci.py +10 -9
- pydantic_ai/ext/langchain.py +4 -2
- pydantic_ai/mcp.py +203 -75
- pydantic_ai/messages.py +2 -2
- pydantic_ai/models/__init__.py +93 -9
- pydantic_ai/models/anthropic.py +16 -7
- pydantic_ai/models/bedrock.py +8 -5
- pydantic_ai/models/cohere.py +1 -4
- pydantic_ai/models/fallback.py +10 -3
- pydantic_ai/models/function.py +9 -4
- pydantic_ai/models/gemini.py +15 -9
- pydantic_ai/models/google.py +84 -20
- pydantic_ai/models/groq.py +17 -14
- pydantic_ai/models/huggingface.py +18 -12
- pydantic_ai/models/instrumented.py +3 -1
- pydantic_ai/models/mcp_sampling.py +3 -1
- pydantic_ai/models/mistral.py +12 -18
- pydantic_ai/models/openai.py +57 -30
- pydantic_ai/models/test.py +3 -0
- pydantic_ai/models/wrapper.py +6 -2
- pydantic_ai/profiles/openai.py +1 -1
- pydantic_ai/providers/google.py +7 -7
- pydantic_ai/result.py +21 -55
- pydantic_ai/run.py +357 -0
- pydantic_ai/tools.py +0 -1
- pydantic_ai/toolsets/__init__.py +2 -0
- pydantic_ai/toolsets/_dynamic.py +87 -0
- pydantic_ai/toolsets/abstract.py +23 -3
- pydantic_ai/toolsets/combined.py +19 -4
- pydantic_ai/toolsets/deferred.py +10 -2
- pydantic_ai/toolsets/function.py +23 -8
- pydantic_ai/toolsets/prefixed.py +4 -0
- pydantic_ai/toolsets/wrapper.py +14 -1
- pydantic_ai/usage.py +17 -1
- {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.1.dist-info}/METADATA +7 -5
- {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.1.dist-info}/RECORD +58 -45
- {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.1.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.1.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.1.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/models/anthropic.py
CHANGED

@@ -21,7 +21,9 @@ from typing_extensions import assert_never
 from pydantic_ai.builtin_tools import CodeExecutionTool, WebSearchTool
 
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
+from .._run_context import RunContext
 from .._utils import guard_tool_call_id as _guard_tool_call_id
+from ..exceptions import UserError
 from ..messages import (
     BinaryContent,
     BuiltinToolCallPart,
@@ -196,13 +198,14 @@ class AnthropicModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         check_allow_model_requests()
         response = await self._messages_create(
             messages, True, cast(AnthropicModelSettings, model_settings or {}), model_request_parameters
         )
         async with response:
-            yield await self._process_streamed_response(response)
+            yield await self._process_streamed_response(response, model_request_parameters)
 
     @property
     def model_name(self) -> AnthropicModelName:
@@ -329,7 +332,9 @@ class AnthropicModel(Model):
 
         return ModelResponse(items, usage=_map_usage(response), model_name=response.model, vendor_id=response.id)
 
-    async def _process_streamed_response(self, response: AsyncStream[BetaRawMessageStreamEvent]) -> StreamedResponse:
+    async def _process_streamed_response(
+        self, response: AsyncStream[BetaRawMessageStreamEvent], model_request_parameters: ModelRequestParameters
+    ) -> StreamedResponse:
         peekable_response = _utils.PeekableAsyncStream(response)
         first_chunk = await peekable_response.peek()
         if isinstance(first_chunk, _utils.Unset):
@@ -338,14 +343,14 @@ class AnthropicModel(Model):
         # Since Anthropic doesn't provide a timestamp in the message, we'll use the current time
         timestamp = datetime.now(tz=timezone.utc)
         return AnthropicStreamedResponse(
-            _model_name=self._model_name, _response=peekable_response, _timestamp=timestamp
+            model_request_parameters=model_request_parameters,
+            _model_name=self._model_name,
+            _response=peekable_response,
+            _timestamp=timestamp,
         )
 
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[BetaToolParam]:
-        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
-        if model_request_parameters.output_tools:
-            tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
-        return tools
+        return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()]
 
     def _get_builtin_tools(self, model_request_parameters: ModelRequestParameters) -> list[BetaToolUnionParam]:
         tools: list[BetaToolUnionParam] = []
@@ -363,6 +368,10 @@ class AnthropicModel(Model):
                 )
             elif isinstance(tool, CodeExecutionTool):  # pragma: no branch
                 tools.append(BetaCodeExecutionTool20250522Param(name='code_execution', type='code_execution_20250522'))
+            else:  # pragma: no cover
+                raise UserError(
+                    f'`{tool.__class__.__name__}` is not supported by `AnthropicModel`. If it should be, please file an issue.'
+                )
         return tools
 
     async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[BetaMessageParam]]:  # noqa: C901
pydantic_ai/models/bedrock.py
CHANGED

@@ -15,6 +15,7 @@ import anyio.to_thread
 from typing_extensions import ParamSpec, assert_never
 
 from pydantic_ai import _utils, usage
+from pydantic_ai._run_context import RunContext
 from pydantic_ai.exceptions import UserError
 from pydantic_ai.messages import (
     AudioUrl,
@@ -230,10 +231,7 @@ class BedrockConverseModel(Model):
         super().__init__(settings=settings, profile=profile or provider.model_profile)
 
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolTypeDef]:
-        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
-        if model_request_parameters.output_tools:
-            tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
-        return tools
+        return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()]
 
     @staticmethod
     def _map_tool_definition(f: ToolDefinition) -> ToolTypeDef:
@@ -269,10 +267,15 @@ class BedrockConverseModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         settings = cast(BedrockModelSettings, model_settings or {})
         response = await self._messages_create(messages, True, settings, model_request_parameters)
-        yield BedrockStreamedResponse(_model_name=self.model_name, _event_stream=response)
+        yield BedrockStreamedResponse(
+            model_request_parameters=model_request_parameters,
+            _model_name=self.model_name,
+            _event_stream=response,
+        )
 
     async def _process_response(self, response: ConverseResponseTypeDef) -> ModelResponse:
         items: list[ModelResponsePart] = []
pydantic_ai/models/cohere.py
CHANGED

@@ -248,10 +248,7 @@ class CohereModel(Model):
         return cohere_messages
 
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolV2]:
-        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
-        if model_request_parameters.output_tools:
-            tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
-        return tools
+        return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()]
 
     @staticmethod
     def _map_tool_call(t: ToolCallPart) -> ToolCallV2:
pydantic_ai/models/fallback.py
CHANGED

@@ -3,13 +3,15 @@ from __future__ import annotations as _annotations
 from collections.abc import AsyncIterator
 from contextlib import AsyncExitStack, asynccontextmanager, suppress
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Callable
+from typing import TYPE_CHECKING, Any, Callable
 
 from opentelemetry.trace import get_current_span
 
+from pydantic_ai._run_context import RunContext
 from pydantic_ai.models.instrumented import InstrumentedModel
 
 from ..exceptions import FallbackExceptionGroup, ModelHTTPError
+from ..settings import merge_model_settings
 from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model
 
 if TYPE_CHECKING:
@@ -64,8 +66,9 @@ class FallbackModel(Model):
 
         for model in self.models:
            customized_model_request_parameters = model.customize_request_parameters(model_request_parameters)
+            merged_settings = merge_model_settings(model.settings, model_settings)
             try:
-                response = await model.request(messages, model_settings, customized_model_request_parameters)
+                response = await model.request(messages, merged_settings, customized_model_request_parameters)
             except Exception as exc:
                 if self._fallback_on(exc):
                     exceptions.append(exc)
@@ -83,16 +86,20 @@ class FallbackModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         """Try each model in sequence until one succeeds."""
         exceptions: list[Exception] = []
 
         for model in self.models:
             customized_model_request_parameters = model.customize_request_parameters(model_request_parameters)
+            merged_settings = merge_model_settings(model.settings, model_settings)
             async with AsyncExitStack() as stack:
                 try:
                     response = await stack.enter_async_context(
-                        model.request_stream(messages, model_settings, customized_model_request_parameters)
+                        model.request_stream(
+                            messages, merged_settings, customized_model_request_parameters, run_context
+                        )
                     )
                 except Exception as exc:
                     if self._fallback_on(exc):
pydantic_ai/models/function.py
CHANGED

@@ -7,13 +7,12 @@ from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime
 from itertools import chain
-from typing import Callable, Union
+from typing import Any, Callable, Union
 
 from typing_extensions import TypeAlias, assert_never, overload
 
-from pydantic_ai.profiles import ModelProfileSpec
-
 from .. import _utils, usage
+from .._run_context import RunContext
 from .._utils import PeekableAsyncStream
 from ..messages import (
     BinaryContent,
@@ -32,6 +31,7 @@ from ..messages import (
     UserContent,
     UserPromptPart,
 )
+from ..profiles import ModelProfileSpec
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import Model, ModelRequestParameters, StreamedResponse
@@ -147,6 +147,7 @@ class FunctionModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         agent_info = AgentInfo(
             model_request_parameters.function_tools,
@@ -165,7 +166,11 @@ class FunctionModel(Model):
         if isinstance(first, _utils.Unset):
             raise ValueError('Stream function must return at least one item')
 
-        yield FunctionStreamedResponse(_model_name=self._model_name, _iter=response_stream)
+        yield FunctionStreamedResponse(
+            model_request_parameters=model_request_parameters,
+            _model_name=self._model_name,
+            _iter=response_stream,
+        )
 
     @property
     def model_name(self) -> str:
pydantic_ai/models/gemini.py
CHANGED

@@ -13,10 +13,9 @@ import pydantic
 from httpx import USE_CLIENT_DEFAULT, Response as HTTPResponse
 from typing_extensions import NotRequired, TypedDict, assert_never, deprecated
 
-from pydantic_ai.providers import Provider, infer_provider
-
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
 from .._output import OutputObjectDefinition
+from .._run_context import RunContext
 from ..exceptions import UserError
 from ..messages import (
     BinaryContent,
@@ -38,6 +37,7 @@ from ..messages import (
     VideoUrl,
 )
 from ..profiles import ModelProfileSpec
+from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
@@ -167,12 +167,13 @@ class GeminiModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         check_allow_model_requests()
         async with self._make_request(
             messages, True, cast(GeminiModelSettings, model_settings or {}), model_request_parameters
         ) as http_response:
-            yield await self._process_streamed_response(http_response)
+            yield await self._process_streamed_response(http_response, model_request_parameters)
 
     @property
     def model_name(self) -> GeminiModelName:
@@ -185,9 +186,7 @@ class GeminiModel(Model):
         return self._system
 
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> _GeminiTools | None:
-        tools = [_function_from_abstract_tool(t) for t in model_request_parameters.function_tools]
-        if model_request_parameters.output_tools:
-            tools += [_function_from_abstract_tool(t) for t in model_request_parameters.output_tools]
+        tools = [_function_from_abstract_tool(t) for t in model_request_parameters.tool_defs.values()]
         return _GeminiTools(function_declarations=tools) if tools else None
 
     def _get_tool_config(
@@ -288,7 +287,9 @@ class GeminiModel(Model):
             vendor_details=vendor_details,
         )
 
-    async def _process_streamed_response(self, http_response: HTTPResponse) -> StreamedResponse:
+    async def _process_streamed_response(
+        self, http_response: HTTPResponse, model_request_parameters: ModelRequestParameters
+    ) -> StreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         aiter_bytes = http_response.aiter_bytes()
         start_response: _GeminiResponse | None = None
@@ -309,7 +310,12 @@ class GeminiModel(Model):
         if start_response is None:
             raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
 
-        return GeminiStreamedResponse(_model_name=self._model_name, _content=content, _stream=aiter_bytes)
+        return GeminiStreamedResponse(
+            model_request_parameters=model_request_parameters,
+            _model_name=self._model_name,
+            _content=content,
+            _stream=aiter_bytes,
+        )
 
     async def _message_to_gemini_content(
         self, messages: list[ModelMessage]
@@ -872,7 +878,7 @@ def _metadata_as_usage(response: _GeminiResponse) -> usage.Usage:
         metadata_details = cast(list[_GeminiModalityTokenCount], metadata_details)
         suffix = key.removesuffix('_details')
         for detail in metadata_details:
-            details[f'{detail["modality"].lower()}_{suffix}'] = detail['token_count']
+            details[f'{detail["modality"].lower()}_{suffix}'] = detail.get('token_count', 0)
 
     return usage.Usage(
         request_tokens=metadata.get('prompt_token_count', 0),
pydantic_ai/models/google.py
CHANGED

@@ -8,11 +8,11 @@ from datetime import datetime
 from typing import Any, Literal, Union, cast, overload
 from uuid import uuid4
 
-from google.genai.types import ExecutableCodeDict
 from typing_extensions import assert_never
 
 from .. import UnexpectedModelBehavior, _utils, usage
 from .._output import OutputObjectDefinition
+from .._run_context import RunContext
 from ..builtin_tools import CodeExecutionTool, WebSearchTool
 from ..exceptions import UserError
 from ..messages import (
@@ -48,16 +48,19 @@ from . import (
 )
 
 try:
-    from google import genai
+    from google.genai import Client
     from google.genai.types import (
         ContentDict,
         ContentUnionDict,
+        CountTokensConfigDict,
+        ExecutableCodeDict,
         FunctionCallDict,
         FunctionCallingConfigDict,
         FunctionCallingConfigMode,
         FunctionDeclarationDict,
         GenerateContentConfigDict,
         GenerateContentResponse,
+        GenerationConfigDict,
         GoogleSearchDict,
         HttpOptionsDict,
         MediaResolution,
@@ -136,10 +139,10 @@ class GoogleModel(Model):
     Apart from `__init__`, all methods are private or match those of the base class.
     """
 
-    client: genai.Client = field(repr=False)
+    client: Client = field(repr=False)
 
     _model_name: GoogleModelName = field(repr=False)
-    _provider: Provider[genai.Client] = field(repr=False)
+    _provider: Provider[Client] = field(repr=False)
     _url: str | None = field(repr=False)
     _system: str = field(default='google', repr=False)
 
@@ -147,7 +150,7 @@ class GoogleModel(Model):
         self,
         model_name: GoogleModelName,
         *,
-        provider: Literal['google-gla', 'google-vertex'] | Provider[genai.Client] = 'google-gla',
+        provider: Literal['google-gla', 'google-vertex'] | Provider[Client] = 'google-gla',
         profile: ModelProfileSpec | None = None,
         settings: ModelSettings | None = None,
     ):
@@ -187,17 +190,71 @@ class GoogleModel(Model):
         response = await self._generate_content(messages, False, model_settings, model_request_parameters)
         return self._process_response(response)
 
+    async def count_tokens(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> usage.Usage:
+        check_allow_model_requests()
+        model_settings = cast(GoogleModelSettings, model_settings or {})
+        contents, generation_config = await self._build_content_and_config(
+            messages, model_settings, model_request_parameters
+        )
+
+        # Annoyingly, the type of `GenerateContentConfigDict.get` is "partially `Unknown`" because `response_schema` includes `typing._UnionGenericAlias`,
+        # so without this we'd need `pyright: ignore[reportUnknownMemberType]` on every line and wouldn't get type checking anyway.
+        generation_config = cast(dict[str, Any], generation_config)
+
+        config = CountTokensConfigDict(
+            http_options=generation_config.get('http_options'),
+        )
+        if self.system != 'google-gla':
+            # The fields are not supported by the Gemini API per https://github.com/googleapis/python-genai/blob/7e4ec284dc6e521949626f3ed54028163ef9121d/google/genai/models.py#L1195-L1214
+            config.update(
+                system_instruction=generation_config.get('system_instruction'),
+                tools=cast(list[ToolDict], generation_config.get('tools')),
+                # Annoyingly, GenerationConfigDict has fewer fields than GenerateContentConfigDict, and no extra fields are allowed.
+                generation_config=GenerationConfigDict(
+                    temperature=generation_config.get('temperature'),
+                    top_p=generation_config.get('top_p'),
+                    max_output_tokens=generation_config.get('max_output_tokens'),
+                    stop_sequences=generation_config.get('stop_sequences'),
+                    presence_penalty=generation_config.get('presence_penalty'),
+                    frequency_penalty=generation_config.get('frequency_penalty'),
+                    thinking_config=generation_config.get('thinking_config'),
+                    media_resolution=generation_config.get('media_resolution'),
+                    response_mime_type=generation_config.get('response_mime_type'),
+                    response_schema=generation_config.get('response_schema'),
+                ),
+            )
+
+        response = await self.client.aio.models.count_tokens(
+            model=self._model_name,
+            contents=contents,
+            config=config,
+        )
+        if response.total_tokens is None:
+            raise UnexpectedModelBehavior(  # pragma: no cover
+                'Total tokens missing from Gemini response', str(response)
+            )
+        return usage.Usage(
+            request_tokens=response.total_tokens,
+            total_tokens=response.total_tokens,
+        )
+
     @asynccontextmanager
     async def request_stream(
         self,
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         check_allow_model_requests()
         model_settings = cast(GoogleModelSettings, model_settings or {})
         response = await self._generate_content(messages, True, model_settings, model_request_parameters)
-        yield await self._process_streamed_response(response)  # type: ignore
+        yield await self._process_streamed_response(response, model_request_parameters)  # type: ignore
 
     @property
     def model_name(self) -> GoogleModelName:
@@ -212,18 +269,17 @@
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolDict] | None:
         tools: list[ToolDict] = [
             ToolDict(function_declarations=[_function_declaration_from_tool(t)])
-            for t in model_request_parameters.function_tools
+            for t in model_request_parameters.tool_defs.values()
         ]
-        if model_request_parameters.output_tools:
-            tools += [
-                ToolDict(function_declarations=[_function_declaration_from_tool(t)])
-                for t in model_request_parameters.output_tools
-            ]
         for tool in model_request_parameters.builtin_tools:
             if isinstance(tool, WebSearchTool):
                 tools.append(ToolDict(google_search=GoogleSearchDict()))
             elif isinstance(tool, CodeExecutionTool):  # pragma: no branch
                 tools.append(ToolDict(code_execution=ToolCodeExecutionDict()))
+            else:  # pragma: no cover
+                raise UserError(
+                    f'`{tool.__class__.__name__}` is not supported by `GoogleModel`. If it should be, please file an issue.'
+                )
         return tools or None
 
     def _get_tool_config(
@@ -264,16 +320,23 @@
         model_settings: GoogleModelSettings,
         model_request_parameters: ModelRequestParameters,
     ) -> GenerateContentResponse | Awaitable[AsyncIterator[GenerateContentResponse]]:
-        tools = self._get_tools(model_request_parameters)
+        contents, config = await self._build_content_and_config(messages, model_settings, model_request_parameters)
+        func = self.client.aio.models.generate_content_stream if stream else self.client.aio.models.generate_content
+        return await func(model=self._model_name, contents=contents, config=config)  # type: ignore
 
+    async def _build_content_and_config(
+        self,
+        messages: list[ModelMessage],
+        model_settings: GoogleModelSettings,
+        model_request_parameters: ModelRequestParameters,
+    ) -> tuple[list[ContentUnionDict], GenerateContentConfigDict]:
+        tools = self._get_tools(model_request_parameters)
         response_mime_type = None
         response_schema = None
         if model_request_parameters.output_mode == 'native':
             if tools:
                 raise UserError('Gemini does not support structured output and tools at the same time.')
-
             response_mime_type = 'application/json'
-
             output_object = model_request_parameters.output_object
             assert output_object is not None
             response_schema = self._map_response_schema(output_object)
@@ -310,9 +373,7 @@
             response_mime_type=response_mime_type,
             response_schema=response_schema,
         )
-
-        func = self.client.aio.models.generate_content_stream if stream else self.client.aio.models.generate_content
-        return await func(model=self._model_name, contents=contents, config=config)  # type: ignore
+        return contents, config
 
     def _process_response(self, response: GenerateContentResponse) -> ModelResponse:
         if not response.candidates or len(response.candidates) != 1:
@@ -336,7 +397,9 @@
             parts, response.model_version or self._model_name, usage, vendor_id=vendor_id, vendor_details=vendor_details
         )
 
-    async def _process_streamed_response(self, response: AsyncIterator[GenerateContentResponse]) -> StreamedResponse:
+    async def _process_streamed_response(
+        self, response: AsyncIterator[GenerateContentResponse], model_request_parameters: ModelRequestParameters
+    ) -> StreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         peekable_response = _utils.PeekableAsyncStream(response)
         first_chunk = await peekable_response.peek()
@@ -344,6 +407,7 @@
             raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')  # pragma: no cover
 
         return GeminiStreamedResponse(
+            model_request_parameters=model_request_parameters,
            _model_name=self._model_name,
             _response=peekable_response,
             _timestamp=first_chunk.create_time or _utils.now_utc(),
@@ -603,7 +667,7 @@ def _metadata_as_usage(response: GenerateContentResponse) -> usage.Usage:
         if key.endswith('_details') and metadata_details:
             suffix = key.removesuffix('_details')
             for detail in metadata_details:
-                details[f'{detail["modality"].lower()}_{suffix}'] = detail['token_count']
+                details[f'{detail["modality"].lower()}_{suffix}'] = detail.get('token_count', 0)
 
     return usage.Usage(
         request_tokens=metadata.get('prompt_token_count', 0),
pydantic_ai/models/groq.py
CHANGED

@@ -5,17 +5,16 @@ from collections.abc import AsyncIterable, AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Literal, Union, cast, overload
+from typing import Any, Literal, Union, cast, overload
 
 from typing_extensions import assert_never
 
-from pydantic_ai._thinking_part import split_content_into_text_and_thinking
-from pydantic_ai.exceptions import UserError
-from pydantic_ai.profiles.groq import GroqModelProfile
-
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
+from .._run_context import RunContext
+from .._thinking_part import split_content_into_text_and_thinking
 from .._utils import generate_tool_call_id, guard_tool_call_id as _guard_tool_call_id, number_to_datetime
-from ..builtin_tools import
+from ..builtin_tools import WebSearchTool
+from ..exceptions import UserError
 from ..messages import (
     BinaryContent,
     BuiltinToolCallPart,
@@ -36,6 +35,7 @@ from ..messages import (
     UserPromptPart,
 )
 from ..profiles import ModelProfile, ModelProfileSpec
+from ..profiles.groq import GroqModelProfile
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
@@ -171,13 +171,14 @@ class GroqModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         check_allow_model_requests()
         response = await self._completions_create(
             messages, True, cast(GroqModelSettings, model_settings or {}), model_request_parameters
         )
         async with response:
-            yield await self._process_streamed_response(response)
+            yield await self._process_streamed_response(response, model_request_parameters)
 
     @property
     def model_name(self) -> GroqModelName:
@@ -287,7 +288,9 @@ class GroqModel(Model):
             items, usage=_map_usage(response), model_name=response.model, timestamp=timestamp, vendor_id=response.id
         )
 
-    async def _process_streamed_response(self, response: AsyncStream[chat.ChatCompletionChunk]) -> GroqStreamedResponse:
+    async def _process_streamed_response(
+        self, response: AsyncStream[chat.ChatCompletionChunk], model_request_parameters: ModelRequestParameters
+    ) -> GroqStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         peekable_response = _utils.PeekableAsyncStream(response)
         first_chunk = await peekable_response.peek()
@@ -297,6 +300,7 @@ class GroqModel(Model):
         )
 
         return GroqStreamedResponse(
+            model_request_parameters=model_request_parameters,
            _response=peekable_response,
             _model_name=self._model_name,
             _model_profile=self.profile,
@@ -304,10 +308,7 @@ class GroqModel(Model):
         )
 
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:
-        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
-        if model_request_parameters.output_tools:
-            tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
-        return tools
+        return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()]
 
     def _get_builtin_tools(
         self, model_request_parameters: ModelRequestParameters
@@ -317,8 +318,10 @@ class GroqModel(Model):
             if isinstance(tool, WebSearchTool):
                 if not GroqModelProfile.from_profile(self.profile).groq_always_has_web_search_builtin_tool:
                     raise UserError('`WebSearchTool` is not supported by Groq')  # pragma: no cover
-            else:
-                raise UserError(f'`{tool.__class__.__name__}` is not supported by `GroqModel`')
+            else:
+                raise UserError(
+                    f'`{tool.__class__.__name__}` is not supported by `GroqModel`. If it should be, please file an issue.'
+                )
         return tools
 
     def _map_messages(self, messages: list[ModelMessage]) -> list[chat.ChatCompletionMessageParam]: