pydantic-ai-slim 0.0.19__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/_parts_manager.py +1 -1
- pydantic_ai/_pydantic.py +1 -0
- pydantic_ai/_result.py +29 -28
- pydantic_ai/_system_prompt.py +4 -4
- pydantic_ai/_utils.py +1 -56
- pydantic_ai/agent.py +137 -113
- pydantic_ai/messages.py +24 -56
- pydantic_ai/models/__init__.py +122 -51
- pydantic_ai/models/anthropic.py +109 -38
- pydantic_ai/models/cohere.py +290 -0
- pydantic_ai/models/function.py +12 -8
- pydantic_ai/models/gemini.py +29 -15
- pydantic_ai/models/groq.py +27 -23
- pydantic_ai/models/mistral.py +34 -29
- pydantic_ai/models/openai.py +45 -23
- pydantic_ai/models/test.py +47 -24
- pydantic_ai/models/vertexai.py +2 -1
- pydantic_ai/result.py +45 -26
- pydantic_ai/settings.py +58 -1
- pydantic_ai/tools.py +29 -26
- {pydantic_ai_slim-0.0.19.dist-info → pydantic_ai_slim-0.0.21.dist-info}/METADATA +6 -4
- pydantic_ai_slim-0.0.21.dist-info/RECORD +29 -0
- pydantic_ai/models/ollama.py +0 -120
- pydantic_ai_slim-0.0.19.dist-info/RECORD +0 -29
- {pydantic_ai_slim-0.0.19.dist-info → pydantic_ai_slim-0.0.21.dist-info}/WHEEL +0 -0
pydantic_ai/models/cohere.py
ADDED
@@ -0,0 +1,290 @@
+from __future__ import annotations as _annotations
+
+from collections.abc import Iterable
+from dataclasses import dataclass, field
+from itertools import chain
+from typing import Literal, Union, cast
+
+from cohere import TextAssistantMessageContentItem
+from httpx import AsyncClient as AsyncHTTPClient
+from typing_extensions import assert_never
+
+from .. import result
+from .._utils import guard_tool_call_id as _guard_tool_call_id
+from ..messages import (
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    ModelResponsePart,
+    RetryPromptPart,
+    SystemPromptPart,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
+    UserPromptPart,
+)
+from ..settings import ModelSettings
+from ..tools import ToolDefinition
+from . import (
+    AgentModel,
+    Model,
+    check_allow_model_requests,
+)
+
+try:
+    from cohere import (
+        AssistantChatMessageV2,
+        AsyncClientV2,
+        ChatMessageV2,
+        ChatResponse,
+        SystemChatMessageV2,
+        ToolCallV2,
+        ToolCallV2Function,
+        ToolChatMessageV2,
+        ToolV2,
+        ToolV2Function,
+        UserChatMessageV2,
+    )
+    from cohere.v2.client import OMIT
+except ImportError as _import_error:
+    raise ImportError(
+        'Please install `cohere` to use the Cohere model, '
+        "you can use the `cohere` optional group — `pip install 'pydantic-ai-slim[cohere]'`"
+    ) from _import_error
+
+NamedCohereModels = Literal[
+    'c4ai-aya-expanse-32b',
+    'c4ai-aya-expanse-8b',
+    'command',
+    'command-light',
+    'command-light-nightly',
+    'command-nightly',
+    'command-r',
+    'command-r-03-2024',
+    'command-r-08-2024',
+    'command-r-plus',
+    'command-r-plus-04-2024',
+    'command-r-plus-08-2024',
+    'command-r7b-12-2024',
+]
+"""Latest / most popular named Cohere models."""
+
+CohereModelName = Union[NamedCohereModels, str]
+
+
+class CohereModelSettings(ModelSettings):
+    """Settings used for a Cohere model request."""
+
+    # This class is a placeholder for any future cohere-specific settings
+
+
+@dataclass(init=False)
+class CohereModel(Model):
+    """A model that uses the Cohere API.
+
+    Internally, this uses the [Cohere Python client](
+    https://github.com/cohere-ai/cohere-python) to interact with the API.
+
+    Apart from `__init__`, all methods are private or match those of the base class.
+    """
+
+    model_name: CohereModelName
+    client: AsyncClientV2 = field(repr=False)
+
+    def __init__(
+        self,
+        model_name: CohereModelName,
+        *,
+        api_key: str | None = None,
+        cohere_client: AsyncClientV2 | None = None,
+        http_client: AsyncHTTPClient | None = None,
+    ):
+        """Initialize an Cohere model.
+
+        Args:
+            model_name: The name of the Cohere model to use. List of model names
+                available [here](https://docs.cohere.com/docs/models#command).
+            api_key: The API key to use for authentication, if not provided, the
+                `CO_API_KEY` environment variable will be used if available.
+            cohere_client: An existing Cohere async client to use. If provided,
+                `api_key` and `http_client` must be `None`.
+            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
+        """
+        self.model_name: CohereModelName = model_name
+        if cohere_client is not None:
+            assert http_client is None, 'Cannot provide both `cohere_client` and `http_client`'
+            assert api_key is None, 'Cannot provide both `cohere_client` and `api_key`'
+            self.client = cohere_client
+        else:
+            self.client = AsyncClientV2(api_key=api_key, httpx_client=http_client)  # type: ignore
+
+    async def agent_model(
+        self,
+        *,
+        function_tools: list[ToolDefinition],
+        allow_text_result: bool,
+        result_tools: list[ToolDefinition],
+    ) -> AgentModel:
+        check_allow_model_requests()
+        tools = [self._map_tool_definition(r) for r in function_tools]
+        if result_tools:
+            tools += [self._map_tool_definition(r) for r in result_tools]
+        return CohereAgentModel(
+            self.client,
+            self.model_name,
+            allow_text_result,
+            tools,
+        )
+
+    def name(self) -> str:
+        return f'cohere:{self.model_name}'
+
+    @staticmethod
+    def _map_tool_definition(f: ToolDefinition) -> ToolV2:
+        return ToolV2(
+            type='function',
+            function=ToolV2Function(
+                name=f.name,
+                description=f.description,
+                parameters=f.parameters_json_schema,
+            ),
+        )
+
+
+@dataclass
+class CohereAgentModel(AgentModel):
+    """Implementation of `AgentModel` for Cohere models."""
+
+    client: AsyncClientV2
+    model_name: CohereModelName
+    allow_text_result: bool
+    tools: list[ToolV2]
+
+    async def request(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> tuple[ModelResponse, result.Usage]:
+        response = await self._chat(messages, cast(CohereModelSettings, model_settings or {}))
+        return self._process_response(response), _map_usage(response)
+
+    async def _chat(
+        self,
+        messages: list[ModelMessage],
+        model_settings: CohereModelSettings,
+    ) -> ChatResponse:
+        cohere_messages = list(chain(*(self._map_message(m) for m in messages)))
+        return await self.client.chat(
+            model=self.model_name,
+            messages=cohere_messages,
+            tools=self.tools or OMIT,
+            max_tokens=model_settings.get('max_tokens', OMIT),
+            temperature=model_settings.get('temperature', OMIT),
+            p=model_settings.get('top_p', OMIT),
+            seed=model_settings.get('seed', OMIT),
+            presence_penalty=model_settings.get('presence_penalty', OMIT),
+            frequency_penalty=model_settings.get('frequency_penalty', OMIT),
+        )
+
+    def _process_response(self, response: ChatResponse) -> ModelResponse:
+        """Process a non-streamed response, and prepare a message to return."""
+        parts: list[ModelResponsePart] = []
+        if response.message.content is not None and len(response.message.content) > 0:
+            # While Cohere's API returns a list, it only does that for future proofing
+            # and currently only one item is being returned.
+            choice = response.message.content[0]
+            parts.append(TextPart(choice.text))
+        for c in response.message.tool_calls or []:
+            if c.function and c.function.name and c.function.arguments:
+                parts.append(
+                    ToolCallPart(
+                        tool_name=c.function.name,
+                        args=c.function.arguments,
+                        tool_call_id=c.id,
+                    )
+                )
+        return ModelResponse(parts=parts, model_name=self.model_name)
+
+    @classmethod
+    def _map_message(cls, message: ModelMessage) -> Iterable[ChatMessageV2]:
+        """Just maps a `pydantic_ai.Message` to a `cohere.ChatMessageV2`."""
+        if isinstance(message, ModelRequest):
+            yield from cls._map_user_message(message)
+        elif isinstance(message, ModelResponse):
+            texts: list[str] = []
+            tool_calls: list[ToolCallV2] = []
+            for item in message.parts:
+                if isinstance(item, TextPart):
+                    texts.append(item.content)
+                elif isinstance(item, ToolCallPart):
+                    tool_calls.append(_map_tool_call(item))
+                else:
+                    assert_never(item)
+            message_param = AssistantChatMessageV2(role='assistant')
+            if texts:
+                message_param.content = [TextAssistantMessageContentItem(text='\n\n'.join(texts))]
+            if tool_calls:
+                message_param.tool_calls = tool_calls
+            yield message_param
+        else:
+            assert_never(message)
+
+    @classmethod
+    def _map_user_message(cls, message: ModelRequest) -> Iterable[ChatMessageV2]:
+        for part in message.parts:
+            if isinstance(part, SystemPromptPart):
+                yield SystemChatMessageV2(role='system', content=part.content)
+            elif isinstance(part, UserPromptPart):
+                yield UserChatMessageV2(role='user', content=part.content)
+            elif isinstance(part, ToolReturnPart):
+                yield ToolChatMessageV2(
+                    role='tool',
+                    tool_call_id=_guard_tool_call_id(t=part, model_source='Cohere'),
+                    content=part.model_response_str(),
+                )
+            elif isinstance(part, RetryPromptPart):
+                if part.tool_name is None:
+                    yield UserChatMessageV2(role='user', content=part.model_response())
+                else:
+                    yield ToolChatMessageV2(
+                        role='tool',
+                        tool_call_id=_guard_tool_call_id(t=part, model_source='Cohere'),
+                        content=part.model_response(),
+                    )
+            else:
+                assert_never(part)
+
+
+def _map_tool_call(t: ToolCallPart) -> ToolCallV2:
+    return ToolCallV2(
+        id=_guard_tool_call_id(t=t, model_source='Cohere'),
+        type='function',
+        function=ToolCallV2Function(
+            name=t.tool_name,
+            arguments=t.args_as_json_str(),
+        ),
+    )
+
+
+def _map_usage(response: ChatResponse) -> result.Usage:
+    usage = response.usage
+    if usage is None:
+        return result.Usage()
+    else:
+        details: dict[str, int] = {}
+        if usage.billed_units is not None:
+            if usage.billed_units.input_tokens:
+                details['input_tokens'] = int(usage.billed_units.input_tokens)
+            if usage.billed_units.output_tokens:
+                details['output_tokens'] = int(usage.billed_units.output_tokens)
+            if usage.billed_units.search_units:
+                details['search_units'] = int(usage.billed_units.search_units)
+            if usage.billed_units.classifications:
+                details['classifications'] = int(usage.billed_units.classifications)
+
+        request_tokens = int(usage.tokens.input_tokens) if usage.tokens and usage.tokens.input_tokens else None
+        response_tokens = int(usage.tokens.output_tokens) if usage.tokens and usage.tokens.output_tokens else None
+        return result.Usage(
+            request_tokens=request_tokens,
+            response_tokens=response_tokens,
+            total_tokens=(request_tokens or 0) + (response_tokens or 0),
+            details=details,
+        )
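For orientation, a minimal usage sketch of the new Cohere support (hypothetical example, not part of the released diff): it assumes the `Agent` API from the rest of the package and an API key either passed explicitly or available via `CO_API_KEY`; the model name is just an example.

# Hypothetical usage sketch, not part of the diff.
from pydantic_ai import Agent
from pydantic_ai.models.cohere import CohereModel

# api_key may be omitted, in which case the CO_API_KEY environment variable is used (per the docstring above).
model = CohereModel('command-r-plus', api_key='your-cohere-api-key')
agent = Agent(model)

result = agent.run_sync('What is the capital of France?')
print(result.data)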
pydantic_ai/models/function.py
CHANGED
@@ -71,16 +71,15 @@ class FunctionModel(Model):
         result_tools: list[ToolDefinition],
     ) -> AgentModel:
         return FunctionAgentModel(
-            self.function,
+            self.function,
+            self.stream_function,
+            AgentInfo(function_tools, allow_text_result, result_tools, None),
         )
 
     def name(self) -> str:
-
-        if self.
-
-        if self.stream_function is not None:
-            labels.append(f'stream-{self.stream_function.__name__}')
-        return f'function:{",".join(labels)}'
+        function_name = self.function.__name__ if self.function is not None else ''
+        stream_function_name = self.stream_function.__name__ if self.stream_function is not None else ''
+        return f'function:{function_name}:{stream_function_name}'
 
 
 @dataclass(frozen=True)
@@ -147,12 +146,15 @@ class FunctionAgentModel(AgentModel):
         agent_info = replace(self.agent_info, model_settings=model_settings)
 
         assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests'
+        model_name = f'function:{self.function.__name__}'
+
         if inspect.iscoroutinefunction(self.function):
             response = await self.function(messages, agent_info)
         else:
             response_ = await _utils.run_in_executor(self.function, messages, agent_info)
             assert isinstance(response_, ModelResponse), response_
             response = response_
+        response.model_name = model_name
         # TODO is `messages` right here? Should it just be new messages?
         return response, _estimate_usage(chain(messages, [response]))
 
@@ -163,13 +165,15 @@ class FunctionAgentModel(AgentModel):
         assert (
             self.stream_function is not None
         ), 'FunctionModel must receive a `stream_function` to support streamed requests'
+        model_name = f'function:{self.stream_function.__name__}'
+
         response_stream = PeekableAsyncStream(self.stream_function(messages, self.agent_info))
 
         first = await response_stream.peek()
         if isinstance(first, _utils.Unset):
             raise ValueError('Stream function must return at least one item')
 
-        yield FunctionStreamedResponse(response_stream)
+        yield FunctionStreamedResponse(_model_name=model_name, _iter=response_stream)
 
 
 @dataclass
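A small sketch of the effect of the reworked `FunctionModel.name()` (hypothetical example, not part of the diff): the name now always contains both slots, with an empty string where no stream function is set, and non-streamed responses are stamped with `function:<function name>`.

from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
from pydantic_ai.models.function import AgentInfo, FunctionModel


def echo_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
    # A trivial model function; per the diff above, its responses get model_name 'function:echo_model'.
    return ModelResponse(parts=[TextPart('hello')])


model = FunctionModel(echo_model)
print(model.name())  # prints 'function:echo_model:' (the stream-function slot is empty here)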
pydantic_ai/models/gemini.py
CHANGED
@@ -7,7 +7,7 @@ from contextlib import asynccontextmanager
 from copy import deepcopy
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Annotated, Any, Literal, Protocol, Union
+from typing import Annotated, Any, Literal, Protocol, Union, cast
 from uuid import uuid4
 
 import pydantic
@@ -48,6 +48,12 @@ See [the Gemini API docs](https://ai.google.dev/gemini-api/docs/models/gemini#mo
 """
 
 
+class GeminiModelSettings(ModelSettings):
+    """Settings used for a Gemini model request."""
+
+    # This class is a placeholder for any future gemini-specific settings
+
+
 @dataclass(init=False)
 class GeminiModel(Model):
     """A model that uses Gemini via `generativelanguage.googleapis.com` API.
@@ -99,6 +105,7 @@ class GeminiModel(Model):
         allow_text_result: bool,
         result_tools: list[ToolDefinition],
     ) -> GeminiAgentModel:
+        check_allow_model_requests()
         return GeminiAgentModel(
             http_client=self.http_client,
             model_name=self.model_name,
@@ -151,7 +158,6 @@ class GeminiAgentModel(AgentModel):
         allow_text_result: bool,
         result_tools: list[ToolDefinition],
     ):
-        check_allow_model_requests()
         tools = [_function_from_abstract_tool(t) for t in function_tools]
         if result_tools:
             tools += [_function_from_abstract_tool(t) for t in result_tools]
@@ -171,7 +177,9 @@ class GeminiAgentModel(AgentModel):
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> tuple[ModelResponse, usage.Usage]:
-        async with self._make_request(
+        async with self._make_request(
+            messages, False, cast(GeminiModelSettings, model_settings or {})
+        ) as http_response:
             response = _gemini_response_ta.validate_json(await http_response.aread())
             return self._process_response(response), _metadata_as_usage(response)
 
@@ -179,12 +187,12 @@ class GeminiAgentModel(AgentModel):
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> AsyncIterator[StreamedResponse]:
-        async with self._make_request(messages, True, model_settings) as http_response:
+        async with self._make_request(messages, True, cast(GeminiModelSettings, model_settings or {})) as http_response:
             yield await self._process_streamed_response(http_response)
 
     @asynccontextmanager
     async def _make_request(
-        self, messages: list[ModelMessage], streamed: bool, model_settings:
+        self, messages: list[ModelMessage], streamed: bool, model_settings: GeminiModelSettings
     ) -> AsyncIterator[HTTPResponse]:
         sys_prompt_parts, contents = self._message_to_gemini_content(messages)
 
@@ -204,6 +212,10 @@ class GeminiAgentModel(AgentModel):
             generation_config['temperature'] = temperature
         if (top_p := model_settings.get('top_p')) is not None:
             generation_config['top_p'] = top_p
+        if (presence_penalty := model_settings.get('presence_penalty')) is not None:
+            generation_config['presence_penalty'] = presence_penalty
+        if (frequency_penalty := model_settings.get('frequency_penalty')) is not None:
+            generation_config['frequency_penalty'] = frequency_penalty
         if generation_config:
             request_data['generation_config'] = generation_config
 
@@ -222,22 +234,20 @@ class GeminiAgentModel(AgentModel):
             url,
             content=request_json,
             headers=headers,
-            timeout=
+            timeout=model_settings.get('timeout', USE_CLIENT_DEFAULT),
         ) as r:
             if r.status_code != 200:
                 await r.aread()
                 raise exceptions.UnexpectedModelBehavior(f'Unexpected response from gemini {r.status_code}', r.text)
             yield r
 
-
-    def _process_response(response: _GeminiResponse) -> ModelResponse:
+    def _process_response(self, response: _GeminiResponse) -> ModelResponse:
         if len(response['candidates']) != 1:
             raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response')
         parts = response['candidates'][0]['content']['parts']
-        return _process_response_from_parts(parts)
+        return _process_response_from_parts(parts, model_name=self.model_name)
 
-
-    async def _process_streamed_response(http_response: HTTPResponse) -> StreamedResponse:
+    async def _process_streamed_response(self, http_response: HTTPResponse) -> StreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         aiter_bytes = http_response.aiter_bytes()
         start_response: _GeminiResponse | None = None
@@ -258,7 +268,7 @@ class GeminiAgentModel(AgentModel):
         if start_response is None:
             raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
 
-        return GeminiStreamedResponse(_content=content, _stream=aiter_bytes)
+        return GeminiStreamedResponse(_model_name=self.model_name, _content=content, _stream=aiter_bytes)
 
     @classmethod
     def _message_to_gemini_content(
@@ -400,6 +410,8 @@ class _GeminiGenerationConfig(TypedDict, total=False):
     max_output_tokens: int
     temperature: float
     top_p: float
+    presence_penalty: float
+    frequency_penalty: float
 
 
 class _GeminiContent(TypedDict):
@@ -432,14 +444,16 @@ def _function_call_part_from_call(tool: ToolCallPart) -> _GeminiFunctionCallPart
     return _GeminiFunctionCallPart(function_call=_GeminiFunctionCall(name=tool.tool_name, args=tool.args_as_dict()))
 
 
-def _process_response_from_parts(
+def _process_response_from_parts(
+    parts: Sequence[_GeminiPartUnion], model_name: GeminiModelName, timestamp: datetime | None = None
+) -> ModelResponse:
     items: list[ModelResponsePart] = []
     for part in parts:
         if 'text' in part:
            items.append(TextPart(content=part['text']))
         elif 'function_call' in part:
             items.append(
-                ToolCallPart
+                ToolCallPart(
                     tool_name=part['function_call']['name'],
                     args=part['function_call']['args'],
                 )
@@ -448,7 +462,7 @@ def _process_response_from_parts(parts: Sequence[_GeminiPartUnion], timestamp: d
         raise exceptions.UnexpectedModelBehavior(
             f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
         )
-    return ModelResponse(items, timestamp=timestamp or _utils.now_utc())
+    return ModelResponse(parts=items, model_name=model_name, timestamp=timestamp or _utils.now_utc())
 
 
 class _GeminiFunctionCall(TypedDict):
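The added `GeminiModelSettings` placeholder and the extra `generation_config` keys mean `presence_penalty` and `frequency_penalty` can now be supplied per run. A hypothetical sketch (not part of the diff), assuming the `model_settings` run argument from the rest of the package and a `GEMINI_API_KEY` in the environment:

from pydantic_ai import Agent

# The penalty values flow through ModelSettings into the Gemini generation_config per the diff above.
agent = Agent('gemini-1.5-flash')
result = agent.run_sync(
    'Summarise the benefits of static typing.',
    model_settings={'presence_penalty': 0.5, 'frequency_penalty': 0.3},
)
print(result.data)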
pydantic_ai/models/groq.py
CHANGED
@@ -5,7 +5,7 @@ from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
 from itertools import chain
-from typing import Literal, overload
+from typing import Literal, cast, overload
 
 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never
@@ -47,10 +47,7 @@ except ImportError as _import_error:
 
 GroqModelName = Literal[
     'llama-3.3-70b-versatile',
-    'llama-3.
-    'llama3-groq-70b-8192-tool-use-preview',
-    'llama3-groq-8b-8192-tool-use-preview',
-    'llama-3.1-70b-specdec',
+    'llama-3.3-70b-specdec',
     'llama-3.1-8b-instant',
     'llama-3.2-1b-preview',
     'llama-3.2-3b-preview',
@@ -60,7 +57,6 @@ GroqModelName = Literal[
     'llama3-8b-8192',
     'mixtral-8x7b-32768',
     'gemma2-9b-it',
-    'gemma-7b-it',
 ]
 """Named Groq models.
 
@@ -68,6 +64,12 @@ See [the Groq docs](https://console.groq.com/docs/models) for a full list.
 """
 
 
+class GroqModelSettings(ModelSettings):
+    """Settings used for a Groq model request."""
+
+    # This class is a placeholder for any future groq-specific settings
+
+
 @dataclass(init=False)
 class GroqModel(Model):
     """A model that uses the Groq API.
@@ -155,31 +157,31 @@ class GroqAgentModel(AgentModel):
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> tuple[ModelResponse, usage.Usage]:
-        response = await self._completions_create(messages, False, model_settings)
+        response = await self._completions_create(messages, False, cast(GroqModelSettings, model_settings or {}))
         return self._process_response(response), _map_usage(response)
 
     @asynccontextmanager
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> AsyncIterator[StreamedResponse]:
-        response = await self._completions_create(messages, True, model_settings)
+        response = await self._completions_create(messages, True, cast(GroqModelSettings, model_settings or {}))
         async with response:
             yield await self._process_streamed_response(response)
 
     @overload
     async def _completions_create(
-        self, messages: list[ModelMessage], stream: Literal[True], model_settings:
+        self, messages: list[ModelMessage], stream: Literal[True], model_settings: GroqModelSettings
     ) -> AsyncStream[ChatCompletionChunk]:
         pass
 
     @overload
     async def _completions_create(
-        self, messages: list[ModelMessage], stream: Literal[False], model_settings:
+        self, messages: list[ModelMessage], stream: Literal[False], model_settings: GroqModelSettings
     ) -> chat.ChatCompletion:
         pass
 
     async def _completions_create(
-        self, messages: list[ModelMessage], stream: bool, model_settings:
+        self, messages: list[ModelMessage], stream: bool, model_settings: GroqModelSettings
     ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
         # standalone function to make it easier to override
         if not self.tools:
@@ -191,13 +193,11 @@ class GroqAgentModel(AgentModel):
 
         groq_messages = list(chain(*(self._map_message(m) for m in messages)))
 
-        model_settings = model_settings or {}
-
         return await self.client.chat.completions.create(
             model=str(self.model_name),
             messages=groq_messages,
             n=1,
-            parallel_tool_calls=
+            parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
             tools=self.tools or NOT_GIVEN,
             tool_choice=tool_choice or NOT_GIVEN,
             stream=stream,
@@ -205,10 +205,13 @@ class GroqAgentModel(AgentModel):
             temperature=model_settings.get('temperature', NOT_GIVEN),
             top_p=model_settings.get('top_p', NOT_GIVEN),
             timeout=model_settings.get('timeout', NOT_GIVEN),
+            seed=model_settings.get('seed', NOT_GIVEN),
+            presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
+            frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
+            logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
         )
 
-
-    def _process_response(response: chat.ChatCompletion) -> ModelResponse:
+    def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
         timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
         choice = response.choices[0]
@@ -217,20 +220,21 @@ class GroqAgentModel(AgentModel):
             items.append(TextPart(content=choice.message.content))
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
-                items.append(
-
-                )
-        return ModelResponse(items, timestamp=timestamp)
+                items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
+        return ModelResponse(items, model_name=self.model_name, timestamp=timestamp)
 
-
-    async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> GroqStreamedResponse:
+    async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> GroqStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         peekable_response = _utils.PeekableAsyncStream(response)
         first_chunk = await peekable_response.peek()
         if isinstance(first_chunk, _utils.Unset):
             raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
 
-        return GroqStreamedResponse(
+        return GroqStreamedResponse(
+            _response=peekable_response,
+            _model_name=self.model_name,
+            _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
+        )
 
     @classmethod
     def _map_message(cls, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]: