pydantic-ai-slim 0.6.1__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic.
- pydantic_ai/__init__.py +5 -0
- pydantic_ai/_a2a.py +6 -4
- pydantic_ai/_agent_graph.py +32 -32
- pydantic_ai/_cli.py +3 -3
- pydantic_ai/_output.py +8 -0
- pydantic_ai/_tool_manager.py +3 -0
- pydantic_ai/_utils.py +7 -1
- pydantic_ai/ag_ui.py +25 -14
- pydantic_ai/{agent.py → agent/__init__.py} +217 -1026
- pydantic_ai/agent/abstract.py +942 -0
- pydantic_ai/agent/wrapper.py +227 -0
- pydantic_ai/builtin_tools.py +105 -0
- pydantic_ai/direct.py +9 -9
- pydantic_ai/durable_exec/__init__.py +0 -0
- pydantic_ai/durable_exec/temporal/__init__.py +83 -0
- pydantic_ai/durable_exec/temporal/_agent.py +699 -0
- pydantic_ai/durable_exec/temporal/_function_toolset.py +92 -0
- pydantic_ai/durable_exec/temporal/_logfire.py +48 -0
- pydantic_ai/durable_exec/temporal/_mcp_server.py +145 -0
- pydantic_ai/durable_exec/temporal/_model.py +168 -0
- pydantic_ai/durable_exec/temporal/_run_context.py +50 -0
- pydantic_ai/durable_exec/temporal/_toolset.py +77 -0
- pydantic_ai/ext/aci.py +10 -9
- pydantic_ai/ext/langchain.py +4 -2
- pydantic_ai/mcp.py +203 -75
- pydantic_ai/messages.py +75 -13
- pydantic_ai/models/__init__.py +66 -8
- pydantic_ai/models/anthropic.py +135 -18
- pydantic_ai/models/bedrock.py +16 -5
- pydantic_ai/models/cohere.py +11 -4
- pydantic_ai/models/fallback.py +4 -2
- pydantic_ai/models/function.py +18 -4
- pydantic_ai/models/gemini.py +20 -9
- pydantic_ai/models/google.py +53 -15
- pydantic_ai/models/groq.py +47 -11
- pydantic_ai/models/huggingface.py +26 -11
- pydantic_ai/models/instrumented.py +3 -1
- pydantic_ai/models/mcp_sampling.py +3 -1
- pydantic_ai/models/mistral.py +27 -17
- pydantic_ai/models/openai.py +97 -33
- pydantic_ai/models/test.py +12 -0
- pydantic_ai/models/wrapper.py +6 -2
- pydantic_ai/profiles/groq.py +23 -0
- pydantic_ai/profiles/openai.py +1 -1
- pydantic_ai/providers/google.py +7 -7
- pydantic_ai/providers/groq.py +2 -0
- pydantic_ai/result.py +21 -55
- pydantic_ai/run.py +357 -0
- pydantic_ai/tools.py +0 -1
- pydantic_ai/toolsets/__init__.py +2 -0
- pydantic_ai/toolsets/_dynamic.py +87 -0
- pydantic_ai/toolsets/abstract.py +23 -3
- pydantic_ai/toolsets/combined.py +19 -4
- pydantic_ai/toolsets/deferred.py +10 -2
- pydantic_ai/toolsets/function.py +23 -8
- pydantic_ai/toolsets/prefixed.py +4 -0
- pydantic_ai/toolsets/wrapper.py +14 -1
- {pydantic_ai_slim-0.6.1.dist-info → pydantic_ai_slim-0.7.0.dist-info}/METADATA +7 -5
- pydantic_ai_slim-0.7.0.dist-info/RECORD +115 -0
- pydantic_ai_slim-0.6.1.dist-info/RECORD +0 -100
- {pydantic_ai_slim-0.6.1.dist-info → pydantic_ai_slim-0.7.0.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.6.1.dist-info → pydantic_ai_slim-0.7.0.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.6.1.dist-info → pydantic_ai_slim-0.7.0.dist-info}/licenses/LICENSE +0 -0
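The headline change visible in this file list is first-class support for provider-executed built-in tools: a new pydantic_ai/builtin_tools.py module, new BuiltinToolCallPart/BuiltinToolReturnPart message parts, and per-provider wiring in the Google, Groq, and other model modules (alongside the new durable_exec/temporal package and the agent.py → agent/ package split). As a rough usage sketch (assuming the Agent constructor accepts a builtin_tools argument, which is not shown in this diff), the new tools would be used along these lines:

    # Sketch only: the Agent-level `builtin_tools` parameter is an assumption, not shown in this diff.
    from pydantic_ai import Agent
    from pydantic_ai.builtin_tools import CodeExecutionTool, WebSearchTool

    agent = Agent(
        'google-gla:gemini-2.0-flash',
        builtin_tools=[WebSearchTool(), CodeExecutionTool()],
    )

    result = agent.run_sync('What is 123456789 * 987654321? Use code execution.')
    print(result.output)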
pydantic_ai/models/gemini.py
CHANGED
@@ -13,13 +13,14 @@ import pydantic
 from httpx import USE_CLIENT_DEFAULT, Response as HTTPResponse
 from typing_extensions import NotRequired, TypedDict, assert_never, deprecated
 
-from pydantic_ai.providers import Provider, infer_provider
-
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
 from .._output import OutputObjectDefinition
+from .._run_context import RunContext
 from ..exceptions import UserError
 from ..messages import (
     BinaryContent,
+    BuiltinToolCallPart,
+    BuiltinToolReturnPart,
     FileUrl,
     ModelMessage,
     ModelRequest,
@@ -36,6 +37,7 @@ from ..messages import (
     VideoUrl,
 )
 from ..profiles import ModelProfileSpec
+from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
@@ -165,12 +167,13 @@ class GeminiModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         check_allow_model_requests()
         async with self._make_request(
             messages, True, cast(GeminiModelSettings, model_settings or {}), model_request_parameters
         ) as http_response:
-            yield await self._process_streamed_response(http_response)
+            yield await self._process_streamed_response(http_response, model_request_parameters)
 
     @property
     def model_name(self) -> GeminiModelName:
@@ -183,9 +186,7 @@ class GeminiModel(Model):
         return self._system
 
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> _GeminiTools | None:
-        tools = [_function_from_abstract_tool(t) for t in model_request_parameters.function_tools]
-        if model_request_parameters.output_tools:
-            tools += [_function_from_abstract_tool(t) for t in model_request_parameters.output_tools]
+        tools = [_function_from_abstract_tool(t) for t in model_request_parameters.tool_defs.values()]
         return _GeminiTools(function_declarations=tools) if tools else None
 
     def _get_tool_config(
@@ -286,7 +287,9 @@ class GeminiModel(Model):
             vendor_details=vendor_details,
         )
 
-    async def _process_streamed_response(self, http_response: HTTPResponse) -> StreamedResponse:
+    async def _process_streamed_response(
+        self, http_response: HTTPResponse, model_request_parameters: ModelRequestParameters
+    ) -> StreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         aiter_bytes = http_response.aiter_bytes()
         start_response: _GeminiResponse | None = None
@@ -307,7 +310,12 @@ class GeminiModel(Model):
         if start_response is None:
             raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
 
-        return GeminiStreamedResponse(_model_name=self._model_name, _content=content, _stream=aiter_bytes)
+        return GeminiStreamedResponse(
+            model_request_parameters=model_request_parameters,
+            _model_name=self._model_name,
+            _content=content,
+            _stream=aiter_bytes,
+        )
 
     async def _message_to_gemini_content(
         self, messages: list[ModelMessage]
@@ -610,6 +618,9 @@ def _content_model_response(m: ModelResponse) -> _GeminiContent:
         elif isinstance(item, TextPart):
             if item.content:
                 parts.append(_GeminiTextPart(text=item.content))
+        elif isinstance(item, (BuiltinToolCallPart, BuiltinToolReturnPart)):  # pragma: no cover
+            # This is currently never returned from gemini
+            pass
         else:
             assert_never(item)
     return _GeminiContent(role='model', parts=parts)
@@ -867,7 +878,7 @@ def _metadata_as_usage(response: _GeminiResponse) -> usage.Usage:
             metadata_details = cast(list[_GeminiModalityTokenCount], metadata_details)
             suffix = key.removesuffix('_details')
             for detail in metadata_details:
-                details[f'{detail["modality"].lower()}_{suffix}'] = detail
+                details[f'{detail["modality"].lower()}_{suffix}'] = detail.get('token_count', 0)
 
     return usage.Usage(
         request_tokens=metadata.get('prompt_token_count', 0),
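The last hunk above fixes how per-modality token counts land in usage details: 0.7.0 stores the integer token count (defaulting to 0 when absent) rather than the raw modality entry. A minimal illustration of the new behaviour, using a made-up usage_metadata payload shaped like Gemini's:

    # Illustrative only; the payload below is invented but mirrors the shape of Gemini's usage_metadata.
    metadata = {
        'prompt_token_count': 10,
        'prompt_tokens_details': [{'modality': 'TEXT', 'token_count': 10}],
        'candidates_tokens_details': [{'modality': 'TEXT'}],  # token_count may be missing
    }

    details: dict[str, int] = {}
    for key, metadata_details in metadata.items():
        if key.endswith('_details') and metadata_details:
            suffix = key.removesuffix('_details')
            for detail in metadata_details:
                # 0.7.0 keeps only the integer count (or 0), not the whole detail mapping
                details[f'{detail["modality"].lower()}_{suffix}'] = detail.get('token_count', 0)

    print(details)  # {'text_prompt_tokens': 10, 'text_candidates_tokens': 0}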
pydantic_ai/models/google.py
CHANGED
@@ -12,9 +12,13 @@ from typing_extensions import assert_never
 
 from .. import UnexpectedModelBehavior, _utils, usage
 from .._output import OutputObjectDefinition
+from .._run_context import RunContext
+from ..builtin_tools import CodeExecutionTool, WebSearchTool
 from ..exceptions import UserError
 from ..messages import (
     BinaryContent,
+    BuiltinToolCallPart,
+    BuiltinToolReturnPart,
     FileUrl,
     ModelMessage,
     ModelRequest,
@@ -44,22 +48,25 @@ from . import (
 )
 
 try:
-    from google import genai
+    from google.genai import Client
     from google.genai.types import (
         ContentDict,
         ContentUnionDict,
+        ExecutableCodeDict,
         FunctionCallDict,
         FunctionCallingConfigDict,
         FunctionCallingConfigMode,
         FunctionDeclarationDict,
         GenerateContentConfigDict,
         GenerateContentResponse,
+        GoogleSearchDict,
         HttpOptionsDict,
         MediaResolution,
         Part,
         PartDict,
         SafetySettingDict,
         ThinkingConfigDict,
+        ToolCodeExecutionDict,
         ToolConfigDict,
         ToolDict,
         ToolListUnionDict,
@@ -130,10 +137,10 @@ class GoogleModel(Model):
     Apart from `__init__`, all methods are private or match those of the base class.
     """
 
-    client: genai.Client = field(repr=False)
+    client: Client = field(repr=False)
 
     _model_name: GoogleModelName = field(repr=False)
-    _provider: Provider[genai.Client] = field(repr=False)
+    _provider: Provider[Client] = field(repr=False)
     _url: str | None = field(repr=False)
     _system: str = field(default='google', repr=False)
 
@@ -141,7 +148,7 @@ class GoogleModel(Model):
         self,
         model_name: GoogleModelName,
         *,
-        provider: Literal['google-gla', 'google-vertex'] | Provider[genai.Client] = 'google-gla',
+        provider: Literal['google-gla', 'google-vertex'] | Provider[Client] = 'google-gla',
         profile: ModelProfileSpec | None = None,
         settings: ModelSettings | None = None,
     ):
@@ -187,11 +194,12 @@ class GoogleModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         check_allow_model_requests()
         model_settings = cast(GoogleModelSettings, model_settings or {})
         response = await self._generate_content(messages, True, model_settings, model_request_parameters)
-        yield await self._process_streamed_response(response)  # type: ignore
+        yield await self._process_streamed_response(response, model_request_parameters)  # type: ignore
 
     @property
     def model_name(self) -> GoogleModelName:
@@ -206,13 +214,17 @@ class GoogleModel(Model):
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolDict] | None:
         tools: list[ToolDict] = [
             ToolDict(function_declarations=[_function_declaration_from_tool(t)])
-            for t in model_request_parameters.function_tools
+            for t in model_request_parameters.tool_defs.values()
         ]
-        if model_request_parameters.output_tools:
-            tools += [
-                ToolDict(function_declarations=[_function_declaration_from_tool(t)])
-                for t in model_request_parameters.output_tools
-            ]
+        for tool in model_request_parameters.builtin_tools:
+            if isinstance(tool, WebSearchTool):
+                tools.append(ToolDict(google_search=GoogleSearchDict()))
+            elif isinstance(tool, CodeExecutionTool):  # pragma: no branch
+                tools.append(ToolDict(code_execution=ToolCodeExecutionDict()))
+            else:  # pragma: no cover
+                raise UserError(
+                    f'`{tool.__class__.__name__}` is not supported by `GoogleModel`. If it should be, please file an issue.'
+                )
         return tools or None
 
     def _get_tool_config(
@@ -325,7 +337,9 @@ class GoogleModel(Model):
             parts, response.model_version or self._model_name, usage, vendor_id=vendor_id, vendor_details=vendor_details
         )
 
-    async def _process_streamed_response(self, response: AsyncIterator[GenerateContentResponse]) -> StreamedResponse:
+    async def _process_streamed_response(
+        self, response: AsyncIterator[GenerateContentResponse], model_request_parameters: ModelRequestParameters
+    ) -> StreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         peekable_response = _utils.PeekableAsyncStream(response)
         first_chunk = await peekable_response.peek()
@@ -333,6 +347,7 @@ class GoogleModel(Model):
             raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')  # pragma: no cover
 
         return GeminiStreamedResponse(
+            model_request_parameters=model_request_parameters,
             _model_name=self._model_name,
             _response=peekable_response,
             _timestamp=first_chunk.create_time or _utils.now_utc(),
@@ -499,6 +514,14 @@ def _content_model_response(m: ModelResponse) -> ContentDict:
             # please open an issue. The below code is the code to send thinking to the provider.
             # parts.append({'text': item.content, 'thought': True})
             pass
+        elif isinstance(item, BuiltinToolCallPart):
+            if item.provider_name == 'google':
+                if item.tool_name == 'code_execution':  # pragma: no branch
+                    parts.append({'executable_code': cast(ExecutableCodeDict, item.args)})
+        elif isinstance(item, BuiltinToolReturnPart):
+            if item.provider_name == 'google':
+                if item.tool_name == 'code_execution':  # pragma: no branch
+                    parts.append({'code_execution_result': item.content})
         else:
             assert_never(item)
     return ContentDict(role='model', parts=parts)
@@ -513,7 +536,22 @@ def _process_response_from_parts(
 ) -> ModelResponse:
     items: list[ModelResponsePart] = []
     for part in parts:
-        if part.text is not None:
+        if part.executable_code is not None:
+            items.append(
+                BuiltinToolCallPart(
+                    provider_name='google', args=part.executable_code.model_dump(), tool_name='code_execution'
+                )
+            )
+        elif part.code_execution_result is not None:
+            items.append(
+                BuiltinToolReturnPart(
+                    provider_name='google',
+                    tool_name='code_execution',
+                    content=part.code_execution_result,
+                    tool_call_id='not_provided',
+                )
+            )
+        elif part.text is not None:
             if part.thought:
                 items.append(ThinkingPart(content=part.text))
             else:
@@ -563,13 +601,13 @@ def _metadata_as_usage(response: GenerateContentResponse) -> usage.Usage:
         details['thoughts_tokens'] = thoughts_token_count
 
     if tool_use_prompt_token_count := metadata.get('tool_use_prompt_token_count'):
-        details['tool_use_prompt_tokens'] = tool_use_prompt_token_count
+        details['tool_use_prompt_tokens'] = tool_use_prompt_token_count
 
     for key, metadata_details in metadata.items():
        if key.endswith('_details') and metadata_details:
            suffix = key.removesuffix('_details')
            for detail in metadata_details:
-                details[f'{detail["modality"].lower()}_{suffix}'] = detail
+                details[f'{detail["modality"].lower()}_{suffix}'] = detail.get('token_count', 0)
 
     return usage.Usage(
         request_tokens=metadata.get('prompt_token_count', 0),
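With the google.py hunks above, Gemini code execution is no longer dropped from responses: an executable_code part is surfaced as a BuiltinToolCallPart and the matching code_execution_result as a BuiltinToolReturnPart, and both are mapped back to provider format on the next request. A rough sketch of the parts a message history might now contain (field values invented for illustration):

    # Sketch: constructing the new message parts directly; the payload values are made up.
    from pydantic_ai.messages import BuiltinToolCallPart, BuiltinToolReturnPart

    call = BuiltinToolCallPart(
        provider_name='google',
        tool_name='code_execution',
        args={'language': 'PYTHON', 'code': 'print(1 + 1)'},
    )
    result = BuiltinToolReturnPart(
        provider_name='google',
        tool_name='code_execution',
        content={'outcome': 'OUTCOME_OK', 'output': '2\n'},
        tool_call_id='not_provided',
    )
    print(call.tool_name, result.content)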
pydantic_ai/models/groq.py
CHANGED
@@ -5,16 +5,20 @@ from collections.abc import AsyncIterable, AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Literal, Union, cast, overload
+from typing import Any, Literal, Union, cast, overload
 
 from typing_extensions import assert_never
 
-from pydantic_ai._thinking_part import split_content_into_text_and_thinking
-
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
-from .._utils import guard_tool_call_id as _guard_tool_call_id, number_to_datetime
+from .._run_context import RunContext
+from .._thinking_part import split_content_into_text_and_thinking
+from .._utils import generate_tool_call_id, guard_tool_call_id as _guard_tool_call_id, number_to_datetime
+from ..builtin_tools import WebSearchTool
+from ..exceptions import UserError
 from ..messages import (
     BinaryContent,
+    BuiltinToolCallPart,
+    BuiltinToolReturnPart,
     DocumentUrl,
     ImageUrl,
     ModelMessage,
@@ -31,6 +35,7 @@ from ..messages import (
     UserPromptPart,
 )
 from ..profiles import ModelProfile, ModelProfileSpec
+from ..profiles.groq import GroqModelProfile
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
@@ -166,13 +171,14 @@ class GroqModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         check_allow_model_requests()
         response = await self._completions_create(
             messages, True, cast(GroqModelSettings, model_settings or {}), model_request_parameters
         )
         async with response:
-            yield await self._process_streamed_response(response)
+            yield await self._process_streamed_response(response, model_request_parameters)
 
     @property
     def model_name(self) -> GroqModelName:
@@ -212,7 +218,7 @@ class GroqModel(Model):
         model_request_parameters: ModelRequestParameters,
     ) -> chat.ChatCompletion | AsyncStream[chat.ChatCompletionChunk]:
         tools = self._get_tools(model_request_parameters)
-
+        tools += self._get_builtin_tools(model_request_parameters)
         if not tools:
             tool_choice: Literal['none', 'required', 'auto'] | None = None
         elif not model_request_parameters.allow_text_output:
@@ -226,7 +232,7 @@ class GroqModel(Model):
         extra_headers = model_settings.get('extra_headers', {})
         extra_headers.setdefault('User-Agent', get_user_agent())
         return await self.client.chat.completions.create(
-            model=str(self._model_name),
+            model=self._model_name,
             messages=groq_messages,
             n=1,
             parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
@@ -256,6 +262,19 @@ class GroqModel(Model):
         timestamp = number_to_datetime(response.created)
         choice = response.choices[0]
         items: list[ModelResponsePart] = []
+        if choice.message.executed_tools:
+            for tool in choice.message.executed_tools:
+                tool_call_id = generate_tool_call_id()
+                items.append(
+                    BuiltinToolCallPart(
+                        tool_name=tool.type, args=tool.arguments, provider_name='groq', tool_call_id=tool_call_id
+                    )
+                )
+                items.append(
+                    BuiltinToolReturnPart(
+                        provider_name='groq', tool_name=tool.type, content=tool.output, tool_call_id=tool_call_id
+                    )
+                )
         # NOTE: The `reasoning` field is only present if `groq_reasoning_format` is set to `parsed`.
         if choice.message.reasoning is not None:
             items.append(ThinkingPart(content=choice.message.reasoning))
@@ -269,7 +288,9 @@ class GroqModel(Model):
             items, usage=_map_usage(response), model_name=response.model, timestamp=timestamp, vendor_id=response.id
         )
 
-    async def _process_streamed_response(self, response: AsyncStream[chat.ChatCompletionChunk]) -> GroqStreamedResponse:
+    async def _process_streamed_response(
+        self, response: AsyncStream[chat.ChatCompletionChunk], model_request_parameters: ModelRequestParameters
+    ) -> GroqStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         peekable_response = _utils.PeekableAsyncStream(response)
         first_chunk = await peekable_response.peek()
@@ -279,6 +300,7 @@ class GroqModel(Model):
         )
 
         return GroqStreamedResponse(
+            model_request_parameters=model_request_parameters,
             _response=peekable_response,
             _model_name=self._model_name,
             _model_profile=self.profile,
@@ -286,9 +308,20 @@ class GroqModel(Model):
         )
 
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:
-        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
-        if model_request_parameters.output_tools:
-            tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
+        return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()]
+
+    def _get_builtin_tools(
+        self, model_request_parameters: ModelRequestParameters
+    ) -> list[chat.ChatCompletionToolParam]:
+        tools: list[chat.ChatCompletionToolParam] = []
+        for tool in model_request_parameters.builtin_tools:
+            if isinstance(tool, WebSearchTool):
+                if not GroqModelProfile.from_profile(self.profile).groq_always_has_web_search_builtin_tool:
+                    raise UserError('`WebSearchTool` is not supported by Groq')  # pragma: no cover
+            else:
+                raise UserError(
+                    f'`{tool.__class__.__name__}` is not supported by `GroqModel`. If it should be, please file an issue.'
+                )
         return tools
 
     def _map_messages(self, messages: list[ModelMessage]) -> list[chat.ChatCompletionMessageParam]:
@@ -308,6 +341,9 @@ class GroqModel(Model):
                     elif isinstance(item, ThinkingPart):
                         # Skip thinking parts when mapping to Groq messages
                         continue
+                    elif isinstance(item, (BuiltinToolCallPart, BuiltinToolReturnPart)):  # pragma: no cover
+                        # This is currently never returned from groq
+                        pass
                     else:
                         assert_never(item)
                 message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')
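For Groq, executed built-in tools reported on the response (choice.message.executed_tools) are now mapped into a call/return pair sharing a freshly generated tool_call_id, so the run history records both what the provider-side tool was asked to do and what it returned. A standalone sketch of that pairing, with a stand-in for the Groq SDK's executed-tool object:

    # Sketch: `ExecutedTool` stands in for the Groq SDK's executed-tool entry; only the pairing idea is shown.
    from __future__ import annotations

    from dataclasses import dataclass
    from uuid import uuid4

    from pydantic_ai.messages import BuiltinToolCallPart, BuiltinToolReturnPart


    @dataclass
    class ExecutedTool:
        type: str
        arguments: str
        output: str


    def map_executed_tool(tool: ExecutedTool) -> list[BuiltinToolCallPart | BuiltinToolReturnPart]:
        # 0.7.0 uses an internal generate_tool_call_id() helper; a uuid stands in here.
        tool_call_id = f'pyd_ai_{uuid4().hex}'
        return [
            BuiltinToolCallPart(tool_name=tool.type, args=tool.arguments, provider_name='groq', tool_call_id=tool_call_id),
            BuiltinToolReturnPart(tool_name=tool.type, content=tool.output, provider_name='groq', tool_call_id=tool_call_id),
        ]


    parts = map_executed_tool(ExecutedTool(type='search', arguments='{"query": "pydantic ai"}', output='...'))
    print([type(p).__name__ for p in parts])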
pydantic_ai/models/huggingface.py
CHANGED
@@ -5,18 +5,20 @@ from collections.abc import AsyncIterable, AsyncIterator
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
-from typing import Literal, Union, cast, overload
+from typing import Any, Literal, Union, cast, overload
 
 from typing_extensions import assert_never
 
-from pydantic_ai._thinking_part import split_content_into_text_and_thinking
-from pydantic_ai.providers import Provider, infer_provider
-
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
+from .._run_context import RunContext
+from .._thinking_part import split_content_into_text_and_thinking
 from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc
+from ..exceptions import UserError
 from ..messages import (
     AudioUrl,
     BinaryContent,
+    BuiltinToolCallPart,
+    BuiltinToolReturnPart,
     DocumentUrl,
     ImageUrl,
     ModelMessage,
@@ -34,9 +36,15 @@ from ..messages import (
     VideoUrl,
 )
 from ..profiles import ModelProfile
+from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests
+from . import (
+    Model,
+    ModelRequestParameters,
+    StreamedResponse,
+    check_allow_model_requests,
+)
 
 try:
     import aiohttp
@@ -147,12 +155,13 @@ class HuggingFaceModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         check_allow_model_requests()
         response = await self._completions_create(
             messages, True, cast(HuggingFaceModelSettings, model_settings or {}), model_request_parameters
         )
-        yield await self._process_streamed_response(response)
+        yield await self._process_streamed_response(response, model_request_parameters)
 
     @property
     def model_name(self) -> HuggingFaceModelName:
@@ -198,6 +207,9 @@ class HuggingFaceModel(Model):
         else:
             tool_choice = 'auto'
 
+        if model_request_parameters.builtin_tools:
+            raise UserError('HuggingFace does not support built-in tools')
+
         hf_messages = await self._map_messages(messages)
 
         try:
@@ -257,7 +269,9 @@ class HuggingFaceModel(Model):
             vendor_id=response.id,
         )
 
-    async def _process_streamed_response(self, response: AsyncIterable[ChatCompletionStreamOutput]) -> StreamedResponse:
+    async def _process_streamed_response(
+        self, response: AsyncIterable[ChatCompletionStreamOutput], model_request_parameters: ModelRequestParameters
+    ) -> StreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         peekable_response = _utils.PeekableAsyncStream(response)
         first_chunk = await peekable_response.peek()
@@ -267,6 +281,7 @@ class HuggingFaceModel(Model):
         )
 
         return HuggingFaceStreamedResponse(
+            model_request_parameters=model_request_parameters,
             _model_name=self._model_name,
             _model_profile=self.profile,
             _response=peekable_response,
@@ -274,10 +289,7 @@ class HuggingFaceModel(Model):
         )
 
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ChatCompletionInputTool]:
-        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
-        if model_request_parameters.output_tools:
-            tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
-        return tools
+        return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()]
 
     async def _map_messages(
         self, messages: list[ModelMessage]
@@ -301,6 +313,9 @@ class HuggingFaceModel(Model):
                         # please open an issue. The below code is the code to send thinking to the provider.
                         # texts.append(f'<think>\n{item.content}\n</think>')
                         pass
+                    elif isinstance(item, (BuiltinToolCallPart, BuiltinToolReturnPart)):  # pragma: no cover
+                        # This is currently never returned from huggingface
+                        pass
                    else:
                        assert_never(item)
                 message_param = ChatCompletionInputMessage(role='assistant')  # type: ignore
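A pattern repeated across the provider hunks above: the model classes no longer read function_tools and output_tools from ModelRequestParameters separately; they iterate a single tool_defs mapping instead. The before/after shape, sketched with simplified stand-ins (the real ModelRequestParameters and ToolDefinition classes have more fields):

    # Sketch: `Params` and `ToolDef` are simplified stand-ins for ModelRequestParameters / ToolDefinition.
    from dataclasses import dataclass, field


    @dataclass
    class ToolDef:
        name: str


    @dataclass
    class Params:
        tool_defs: dict[str, ToolDef] = field(default_factory=dict)


    def map_tool(t: ToolDef) -> dict:
        return {'type': 'function', 'function': {'name': t.name}}


    params = Params(tool_defs={'get_weather': ToolDef('get_weather'), 'final_result': ToolDef('final_result')})

    # 0.6.1 style (two separate lists):
    #     tools = [map_tool(t) for t in params.function_tools]
    #     tools += [map_tool(t) for t in params.output_tools]
    # 0.7.0 style (one mapping):
    tools = [map_tool(t) for t in params.tool_defs.values()]
    print([t['function']['name'] for t in tools])  # ['get_weather', 'final_result']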
pydantic_ai/models/instrumented.py
CHANGED
@@ -18,6 +18,7 @@ from opentelemetry.trace import Span, Tracer, TracerProvider, get_tracer_provider
 from opentelemetry.util.types import AttributeValue
 from pydantic import TypeAdapter
 
+from .._run_context import RunContext
 from ..messages import ModelMessage, ModelRequest, ModelResponse
 from ..settings import ModelSettings
 from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse
@@ -218,12 +219,13 @@ class InstrumentedModel(WrapperModel):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         with self._instrument(messages, model_settings, model_request_parameters) as finish:
             response_stream: StreamedResponse | None = None
             try:
                 async with super().request_stream(
-                    messages, model_settings, model_request_parameters
+                    messages, model_settings, model_request_parameters, run_context
                 ) as response_stream:
                     yield response_stream
             finally:
pydantic_ai/models/mcp_sampling.py
CHANGED
@@ -3,9 +3,10 @@ from __future__ import annotations as _annotations
 from collections.abc import AsyncIterator
 from contextlib import asynccontextmanager
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, cast
+from typing import TYPE_CHECKING, Any, cast
 
 from .. import _mcp, exceptions, usage
+from .._run_context import RunContext
 from ..messages import ModelMessage, ModelResponse
 from ..settings import ModelSettings
 from . import Model, ModelRequestParameters, StreamedResponse
@@ -76,6 +77,7 @@ class MCPSamplingModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         raise NotImplementedError('MCP Sampling does not support streaming')
         yield