pydantic-ai-slim 0.0.11__py3-none-any.whl → 0.0.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of pydantic-ai-slim might be problematic.
- pydantic_ai/_pydantic.py +13 -29
- pydantic_ai/_result.py +52 -38
- pydantic_ai/_system_prompt.py +1 -1
- pydantic_ai/_utils.py +20 -8
- pydantic_ai/agent.py +431 -167
- pydantic_ai/messages.py +90 -48
- pydantic_ai/models/__init__.py +59 -42
- pydantic_ai/models/anthropic.py +344 -0
- pydantic_ai/models/function.py +66 -44
- pydantic_ai/models/gemini.py +160 -117
- pydantic_ai/models/groq.py +125 -108
- pydantic_ai/models/mistral.py +680 -0
- pydantic_ai/models/ollama.py +116 -0
- pydantic_ai/models/openai.py +145 -114
- pydantic_ai/models/test.py +109 -77
- pydantic_ai/models/vertexai.py +14 -9
- pydantic_ai/result.py +35 -37
- pydantic_ai/settings.py +72 -0
- pydantic_ai/tools.py +140 -45
- {pydantic_ai_slim-0.0.11.dist-info → pydantic_ai_slim-0.0.13.dist-info}/METADATA +8 -3
- pydantic_ai_slim-0.0.13.dist-info/RECORD +26 -0
- {pydantic_ai_slim-0.0.11.dist-info → pydantic_ai_slim-0.0.13.dist-info}/WHEEL +1 -1
- pydantic_ai_slim-0.0.11.dist-info/RECORD +0 -22
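The headline additions in this release are three new model backends (anthropic.py, mistral.py, ollama.py) and a new settings.py module whose `ModelSettings` type threads per-request options (max tokens, temperature, top-p, timeout) through every model's request methods, as the groq.py diff below shows. A minimal sketch of the caller-facing side, assuming `Agent.run_sync` gained a `model_settings` argument in this release:

```python
# Sketch only: the ModelSettings keys are taken from the groq.py diff below;
# the Agent wiring is an assumption about the 0.0.13 API.
from pydantic_ai import Agent

agent = Agent('groq:llama-3.3-70b-versatile')
result = agent.run_sync(
    'What is the capital of France?',
    model_settings={'temperature': 0.0, 'max_tokens': 256},  # per-request overrides
)
print(result.data)
```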
pydantic_ai/models/groq.py
CHANGED
```diff
@@ -1,28 +1,34 @@
 from __future__ import annotations as _annotations
 
-from collections.abc import AsyncIterator, Iterable, Mapping, Sequence
+from collections.abc import AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
+from itertools import chain
 from typing import Literal, overload
 
 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never
 
 from .. import UnexpectedModelBehavior, _utils, result
+from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
     ArgsJson,
-    Message,
-    ModelAnyResponse,
-    ModelStructuredResponse,
-    ModelTextResponse,
-    RetryPrompt,
-    ToolCall,
-    ToolReturn,
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    ModelResponsePart,
+    RetryPromptPart,
+    SystemPromptPart,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
+    UserPromptPart,
 )
 from ..result import Cost
+from ..settings import ModelSettings
+from ..tools import ToolDefinition
 from . import (
-    AbstractToolDefinition,
     AgentModel,
     EitherStreamedResponse,
     Model,
```
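The import rewrite tracks this release's messages refactor: the old role-tagged `Message` union (`ModelTextResponse`, `ModelStructuredResponse`, `ToolCall`, ...) gives way to two message kinds, `ModelRequest` and `ModelResponse`, each carrying a list of typed parts. A rough sketch of the new shapes, with constructor fields inferred from how this diff uses them (`.parts`, `.content`):

```python
# Field names inferred from the diff, not a verified API listing.
from pydantic_ai.messages import (
    ModelRequest,
    ModelResponse,
    SystemPromptPart,
    TextPart,
    UserPromptPart,
)

request = ModelRequest(
    parts=[
        SystemPromptPart(content='Be concise.'),
        UserPromptPart(content='What is 2 + 2?'),
    ]
)
response = ModelResponse(parts=[TextPart(content='4')])
```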
```diff
@@ -37,13 +43,14 @@ try:
     from groq.types import chat
     from groq.types.chat import ChatCompletion, ChatCompletionChunk
     from groq.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
-except ImportError as e:
+except ImportError as _import_error:
     raise ImportError(
         'Please install `groq` to use the Groq model, '
-        "you can use the `groq` optional group — `pip install 'pydantic-ai[groq]'`"
-    ) from e
+        "you can use the `groq` optional group — `pip install 'pydantic-ai-slim[groq]'`"
+    ) from _import_error
 
 GroqModelName = Literal[
+    'llama-3.3-70b-versatile',
     'llama-3.1-70b-versatile',
     'llama3-groq-70b-8192-tool-use-preview',
     'llama3-groq-8b-8192-tool-use-preview',
```
```diff
@@ -109,13 +116,14 @@ class GroqModel(Model):
 
     async def agent_model(
         self,
-        function_tools: Mapping[str, AbstractToolDefinition],
+        *,
+        function_tools: list[ToolDefinition],
         allow_text_result: bool,
-        result_tools: Sequence[AbstractToolDefinition] | None,
+        result_tools: list[ToolDefinition],
     ) -> AgentModel:
         check_allow_model_requests()
-        tools = [self._map_tool_definition(r) for r in function_tools.values()]
-        if result_tools is not None:
+        tools = [self._map_tool_definition(r) for r in function_tools]
+        if result_tools:
             tools += [self._map_tool_definition(r) for r in result_tools]
         return GroqAgentModel(
             self.client,
```
```diff
@@ -128,13 +136,13 @@ class GroqModel(Model):
         return f'groq:{self.model_name}'
 
     @staticmethod
-    def _map_tool_definition(f: AbstractToolDefinition) -> chat.ChatCompletionToolParam:
+    def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
         return {
             'type': 'function',
             'function': {
                 'name': f.name,
                 'description': f.description,
-                'parameters': f.json_schema,
+                'parameters': f.parameters_json_schema,
             },
         }
```
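`AbstractToolDefinition` is gone: tools now arrive as the concrete `ToolDefinition` from the new `pydantic_ai.tools` module, and its `json_schema` attribute is renamed `parameters_json_schema`. Roughly, the mapping above turns one into a Groq function-tool param like this (the `ToolDefinition` constructor call is an assumption based on the three fields the diff reads):

```python
from pydantic_ai.tools import ToolDefinition

tool_def = ToolDefinition(
    name='get_weather',
    description='Return the current weather for a city.',
    parameters_json_schema={
        'type': 'object',
        'properties': {'city': {'type': 'string'}},
        'required': ['city'],
    },
)

# What _map_tool_definition builds from it, per the hunk above:
param = {
    'type': 'function',
    'function': {
        'name': tool_def.name,
        'description': tool_def.description,
        'parameters': tool_def.parameters_json_schema,
    },
}
```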
```diff
@@ -148,28 +156,34 @@ class GroqAgentModel(AgentModel):
     allow_text_result: bool
     tools: list[chat.ChatCompletionToolParam]
 
-    async def request(self, messages: list[Message]) -> tuple[ModelAnyResponse, result.Cost]:
-        response = await self._completions_create(messages, False)
+    async def request(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> tuple[ModelResponse, result.Cost]:
+        response = await self._completions_create(messages, False, model_settings)
         return self._process_response(response), _map_cost(response)
 
     @asynccontextmanager
-    async def request_stream(self, messages: list[Message]) -> AsyncIterator[EitherStreamedResponse]:
-        response = await self._completions_create(messages, True)
+    async def request_stream(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> AsyncIterator[EitherStreamedResponse]:
+        response = await self._completions_create(messages, True, model_settings)
         async with response:
             yield await self._process_streamed_response(response)
 
     @overload
     async def _completions_create(
-        self, messages: list[Message], stream: Literal[True]
+        self, messages: list[ModelMessage], stream: Literal[True], model_settings: ModelSettings | None
     ) -> AsyncStream[ChatCompletionChunk]:
         pass
 
     @overload
-    async def _completions_create(self, messages: list[Message], stream: Literal[False]) -> chat.ChatCompletion:
+    async def _completions_create(
+        self, messages: list[ModelMessage], stream: Literal[False], model_settings: ModelSettings | None
+    ) -> chat.ChatCompletion:
         pass
 
     async def _completions_create(
-        self, messages: list[Message], stream: bool
+        self, messages: list[ModelMessage], stream: bool, model_settings: ModelSettings | None
    ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
         # standalone function to make it easier to override
         if not self.tools:
```
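Both public entry points now take `model_settings` explicitly, so per-run settings flow through without the model holding state. Driving the new signature directly might look like the following (normally the Agent does this; the model construction and `GROQ_API_KEY` handling are illustrative assumptions):

```python
import asyncio

from pydantic_ai.messages import ModelRequest, UserPromptPart
from pydantic_ai.models.groq import GroqModel


async def main() -> None:
    model = GroqModel('llama-3.3-70b-versatile')  # assumed to read GROQ_API_KEY from the environment
    agent_model = await model.agent_model(function_tools=[], allow_text_result=True, result_tools=[])
    response, cost = await agent_model.request(
        [ModelRequest(parts=[UserPromptPart(content='hello')])],
        {'temperature': 0.0},  # the new ModelSettings | None argument
    )
    print(response, cost)


asyncio.run(main())
```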
```diff
@@ -179,103 +193,112 @@ class GroqAgentModel(AgentModel):
         else:
             tool_choice = 'auto'
 
-        groq_messages = [self._map_message(m) for m in messages]
+        groq_messages = list(chain(*(self._map_message(m) for m in messages)))
+
+        model_settings = model_settings or {}
+
         return await self.client.chat.completions.create(
             model=str(self.model_name),
             messages=groq_messages,
-            temperature=0.0,
             n=1,
             parallel_tool_calls=True if self.tools else NOT_GIVEN,
             tools=self.tools or NOT_GIVEN,
             tool_choice=tool_choice or NOT_GIVEN,
             stream=stream,
+            max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
+            temperature=model_settings.get('temperature', NOT_GIVEN),
+            top_p=model_settings.get('top_p', NOT_GIVEN),
+            timeout=model_settings.get('timeout', NOT_GIVEN),
         )
 
     @staticmethod
-    def _process_response(response: chat.ChatCompletion) -> ModelAnyResponse:
+    def _process_response(response: chat.ChatCompletion) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
         timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
         choice = response.choices[0]
+        items: list[ModelResponsePart] = []
+        if choice.message.content is not None:
+            items.append(TextPart(choice.message.content))
         if choice.message.tool_calls is not None:
-            return ModelStructuredResponse(
-                [ToolCall.from_json(c.function.name, c.function.arguments, c.id) for c in choice.message.tool_calls],
-                timestamp=timestamp,
-            )
-        else:
-            assert choice.message.content is not None, choice
-            return ModelTextResponse(choice.message.content, timestamp=timestamp)
+            for c in choice.message.tool_calls:
+                items.append(ToolCallPart.from_json(c.function.name, c.function.arguments, c.id))
+        return ModelResponse(items, timestamp=timestamp)
 
     @staticmethod
     async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> EitherStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
-        try:
-            first_chunk = await response.__anext__()
-        except StopAsyncIteration as e:
-            raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e
-        timestamp = datetime.fromtimestamp(first_chunk.created, tz=timezone.utc)
-        delta = first_chunk.choices[0].delta
-        start_cost = _map_cost(first_chunk)
-
-        # the first chunk may only contain `role`, so we iterate until we get either `tool_calls` or `content`
-        while delta.tool_calls is None and delta.content is None:
+        timestamp: datetime | None = None
+        start_cost = Cost()
+        # the first chunk may contain enough information so we iterate until we get either `tool_calls` or `content`
+        while True:
             try:
-                next_chunk = await response.__anext__()
+                chunk = await response.__anext__()
             except StopAsyncIteration as e:
                 raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e
-            delta = next_chunk.choices[0].delta
-            start_cost += _map_cost(next_chunk)
-
-        if delta.content is not None:
-            return GroqStreamTextResponse(delta.content, response, timestamp, start_cost)
+            timestamp = timestamp or datetime.fromtimestamp(chunk.created, tz=timezone.utc)
+            start_cost += _map_cost(chunk)
+
+            if chunk.choices:
+                delta = chunk.choices[0].delta
+
+                if delta.content is not None:
+                    return GroqStreamTextResponse(delta.content, response, timestamp, start_cost)
+                elif delta.tool_calls is not None:
+                    return GroqStreamStructuredResponse(
+                        response,
+                        {c.index: c for c in delta.tool_calls},
+                        timestamp,
+                        start_cost,
+                    )
+
+    @classmethod
+    def _map_message(cls, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
+        """Just maps a `pydantic_ai.Message` to a `groq.types.ChatCompletionMessageParam`."""
+        if isinstance(message, ModelRequest):
+            yield from cls._map_user_message(message)
+        elif isinstance(message, ModelResponse):
+            texts: list[str] = []
+            tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
+            for item in message.parts:
+                if isinstance(item, TextPart):
+                    texts.append(item.content)
+                elif isinstance(item, ToolCallPart):
+                    tool_calls.append(_map_tool_call(item))
+                else:
+                    assert_never(item)
+            message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')
+            if texts:
+                # Note: model responses from this model should only have one text item, so the following
+                # shouldn't merge multiple texts into one unless you switch models between runs:
+                message_param['content'] = '\n\n'.join(texts)
+            if tool_calls:
+                message_param['tool_calls'] = tool_calls
+            yield message_param
         else:
-            assert delta.tool_calls is not None
-            return GroqStreamStructuredResponse(
-                response,
-                {c.index: c for c in delta.tool_calls},
-                timestamp,
-                start_cost,
-            )
+            assert_never(message)
 
-    @staticmethod
-    def _map_message(message: Message) -> chat.ChatCompletionMessageParam:
-        """Just maps a `pydantic_ai.Message` to a `groq.types.ChatCompletionMessageParam`."""
-        if message.role == 'system':
-            # SystemPrompt ->
-            return chat.ChatCompletionSystemMessageParam(role='system', content=message.content)
-        elif message.role == 'user':
-            # UserPrompt ->
-            return chat.ChatCompletionUserMessageParam(role='user', content=message.content)
-        elif message.role == 'tool-return':
-            # ToolReturn ->
-            return chat.ChatCompletionToolMessageParam(
-                role='tool',
-                tool_call_id=_guard_tool_id(message),
-                content=message.model_response_str(),
-            )
-        elif message.role == 'retry-prompt':
-            # RetryPrompt ->
-            if message.tool_name is None:
-                return chat.ChatCompletionUserMessageParam(role='user', content=message.model_response())
-            else:
-                return chat.ChatCompletionToolMessageParam(
+    @classmethod
+    def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
+        for part in message.parts:
+            if isinstance(part, SystemPromptPart):
+                yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
+            elif isinstance(part, UserPromptPart):
+                yield chat.ChatCompletionUserMessageParam(role='user', content=part.content)
+            elif isinstance(part, ToolReturnPart):
+                yield chat.ChatCompletionToolMessageParam(
                     role='tool',
-                    tool_call_id=_guard_tool_id(message),
-                    content=message.model_response(),
+                    tool_call_id=_guard_tool_call_id(t=part, model_source='Groq'),
+                    content=part.model_response_str(),
                 )
-        elif message.role == 'model-text-response':
-            # ModelTextResponse ->
-            return chat.ChatCompletionAssistantMessageParam(
-                role='assistant',
-                content=message.content,
-            )
-        elif message.role == 'model-structured-response':
-            # ModelStructuredResponse ->
-            return chat.ChatCompletionAssistantMessageParam(
-                role='assistant',
-                tool_calls=[_map_tool_call(t) for t in message.calls],
-            )
-        else:
-            assert_never(message)
+            elif isinstance(part, RetryPromptPart):
+                if part.tool_name is None:
+                    yield chat.ChatCompletionUserMessageParam(role='user', content=part.model_response())
+                else:
+                    yield chat.ChatCompletionToolMessageParam(
+                        role='tool',
+                        tool_call_id=_guard_tool_call_id(t=part, model_source='Groq'),
+                        content=part.model_response(),
+                    )
 
 
 @dataclass
```
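Two details stand out in this hunk: `_map_message` is now a generator (a single `ModelRequest` can expand to several Groq messages, hence the `chain` flattening), and every sampling knob reads from `ModelSettings` with groq's `NOT_GIVEN` sentinel as fallback, replacing the previously hardcoded `temperature=0.0`. The settings type is presumably a `total=False` TypedDict along these lines (a sketch based only on the keys this hunk reads; the real definition lives in the new settings.py):

```python
from typing import TypedDict


class ModelSettings(TypedDict, total=False):
    """Assumed shape; only the keys the groq.py hunk above reads."""

    max_tokens: int
    temperature: float
    top_p: float
    timeout: float


NOT_GIVEN = object()  # stand-in for the groq SDK's sentinel

settings: ModelSettings = {'temperature': 0.0}
assert settings.get('temperature', NOT_GIVEN) == 0.0
assert settings.get('max_tokens', NOT_GIVEN) is NOT_GIVEN  # unset keys fall back to the sentinel
```

The `model_settings = model_settings or {}` line runs first because `None` has no `.get()`.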
```diff
@@ -352,14 +375,14 @@ class GroqStreamStructuredResponse(StreamStructuredResponse):
         else:
             self._delta_tool_calls[new.index] = new
 
-    def get(self, *, final: bool = False) -> ModelStructuredResponse:
-        calls: list[ToolCall] = []
+    def get(self, *, final: bool = False) -> ModelResponse:
+        items: list[ModelResponsePart] = []
         for c in self._delta_tool_calls.values():
             if f := c.function:
                 if f.name is not None and f.arguments is not None:
-                    calls.append(ToolCall.from_json(f.name, f.arguments, c.id))
+                    items.append(ToolCallPart.from_json(f.name, f.arguments, c.id))
 
-        return ModelStructuredResponse(calls, timestamp=self._timestamp)
+        return ModelResponse(items, timestamp=self._timestamp)
 
     def cost(self) -> Cost:
         return self._cost
```
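`get()` can be called repeatedly while the stream is still arriving: it rebuilds a `ModelResponse` from whatever tool-call deltas have accumulated so far, skipping any entry whose name or arguments are still missing. The accumulation feeding it (truncated in the hunk above) merges fragments by index, in the spirit of this toy example (values invented):

```python
# Toy illustration of by-index merging of streamed tool-call fragments;
# argument JSON arrives as text pieces that are concatenated.
fragments = [
    {'index': 0, 'name': 'get_weather', 'arguments': '{"city"'},
    {'index': 0, 'name': None, 'arguments': ': "Paris"}'},
]
acc: dict[int, dict] = {}
for f in fragments:
    if current := acc.get(f['index']):
        current['arguments'] += f['arguments']
    else:
        acc[f['index']] = dict(f)

assert acc[0]['arguments'] == '{"city": "Paris"}'
```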
```diff
@@ -368,16 +391,10 @@ class GroqStreamStructuredResponse(StreamStructuredResponse):
         return self._timestamp
 
 
-def _guard_tool_id(t: ToolCall | ToolReturn | RetryPrompt) -> str:
-    """Type guard that checks a `tool_id` is not None both for static typing and runtime."""
-    assert t.tool_id is not None, f'Groq requires `tool_id` to be set: {t}'
-    return t.tool_id
-
-
-def _map_tool_call(t: ToolCall) -> chat.ChatCompletionMessageToolCallParam:
+def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
     assert isinstance(t.args, ArgsJson), f'Expected ArgsJson, got {t.args}'
     return chat.ChatCompletionMessageToolCallParam(
-        id=_guard_tool_id(t),
+        id=_guard_tool_call_id(t=t, model_source='Groq'),
         type='function',
         function={'name': t.tool_name, 'arguments': t.args.args_json},
     )
```