pydantic-ai-slim 0.0.12__py3-none-any.whl → 0.0.13__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release: this version of pydantic-ai-slim has been flagged as potentially problematic.
- pydantic_ai/_pydantic.py +7 -25
- pydantic_ai/_result.py +34 -16
- pydantic_ai/_system_prompt.py +1 -1
- pydantic_ai/_utils.py +9 -2
- pydantic_ai/agent.py +333 -148
- pydantic_ai/messages.py +87 -48
- pydantic_ai/models/__init__.py +30 -6
- pydantic_ai/models/anthropic.py +344 -0
- pydantic_ai/models/function.py +59 -31
- pydantic_ai/models/gemini.py +150 -108
- pydantic_ai/models/groq.py +94 -74
- pydantic_ai/models/mistral.py +680 -0
- pydantic_ai/models/ollama.py +1 -1
- pydantic_ai/models/openai.py +102 -76
- pydantic_ai/models/test.py +62 -51
- pydantic_ai/models/vertexai.py +7 -3
- pydantic_ai/result.py +35 -37
- pydantic_ai/settings.py +72 -0
- pydantic_ai/tools.py +28 -18
- {pydantic_ai_slim-0.0.12.dist-info → pydantic_ai_slim-0.0.13.dist-info}/METADATA +8 -3
- pydantic_ai_slim-0.0.13.dist-info/RECORD +26 -0
- {pydantic_ai_slim-0.0.12.dist-info → pydantic_ai_slim-0.0.13.dist-info}/WHEEL +1 -1
- pydantic_ai_slim-0.0.12.dist-info/RECORD +0 -23
pydantic_ai/models/groq.py
CHANGED
@@ -4,23 +4,29 @@ from collections.abc import AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
+from itertools import chain
 from typing import Literal, overload
 
 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never
 
 from .. import UnexpectedModelBehavior, _utils, result
+from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
     ArgsJson,
-    Message,
-    ModelAnyResponse,
-    ModelStructuredResponse,
-    ModelTextResponse,
-    RetryPrompt,
-    ToolCall,
-    ToolReturn,
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    ModelResponsePart,
+    RetryPromptPart,
+    SystemPromptPart,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
+    UserPromptPart,
 )
 from ..result import Cost
+from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
     AgentModel,
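These import changes track the release's message-model refactor: the `Message`-style types removed above are replaced by `ModelRequest` and `ModelResponse` containers made up of typed parts (`SystemPromptPart`, `UserPromptPart`, `TextPart`, `ToolCallPart`, and so on). A rough sketch of the new shapes, with constructor usage inferred from how the parts are built elsewhere in this diff:

    from pydantic_ai.messages import (
        ModelRequest,
        ModelResponse,
        SystemPromptPart,
        TextPart,
        UserPromptPart,
    )

    # A "request" now carries every user-side part (system prompts, user
    # prompts, tool returns, retry prompts) as one ModelRequest...
    request = ModelRequest(parts=[
        SystemPromptPart(content='Be concise.'),
        UserPromptPart(content='What is the capital of France?'),
    ])

    # ...and a model reply is a ModelResponse whose parts can mix plain
    # text with tool calls, as _process_response below now constructs.
    response = ModelResponse(parts=[TextPart(content='Paris.')])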
@@ -40,10 +46,11 @@ try:
 except ImportError as _import_error:
     raise ImportError(
         'Please install `groq` to use the Groq model, '
-        "you can use the `groq` optional group — `pip install 'pydantic-ai[groq]'`"
+        "you can use the `groq` optional group — `pip install 'pydantic-ai-slim[groq]'`"
     ) from _import_error
 
 GroqModelName = Literal[
+    'llama-3.3-70b-versatile',
     'llama-3.1-70b-versatile',
     'llama3-groq-70b-8192-tool-use-preview',
     'llama3-groq-8b-8192-tool-use-preview',
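Two small user-facing fixes land in this hunk: the import-error hint now names the `pydantic-ai-slim` distribution this wheel actually belongs to, and `llama-3.3-70b-versatile` becomes a valid `GroqModelName`. A minimal usage sketch (assuming the constructor keeps its documented form and reads `GROQ_API_KEY` from the environment):

    # pip install 'pydantic-ai-slim[groq]'
    from pydantic_ai import Agent
    from pydantic_ai.models.groq import GroqModel

    # 'llama-3.3-70b-versatile' now type-checks as a GroqModelName literal.
    model = GroqModel('llama-3.3-70b-versatile')
    agent = Agent(model)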
@@ -149,28 +156,34 @@ class GroqAgentModel(AgentModel):
     allow_text_result: bool
     tools: list[chat.ChatCompletionToolParam]
 
-    async def request(self, messages: list[Message]) -> tuple[ModelAnyResponse, result.Cost]:
-        response = await self._completions_create(messages, False)
+    async def request(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> tuple[ModelResponse, result.Cost]:
+        response = await self._completions_create(messages, False, model_settings)
         return self._process_response(response), _map_cost(response)
 
     @asynccontextmanager
-    async def request_stream(self, messages: list[Message]) -> AsyncIterator[EitherStreamedResponse]:
-        response = await self._completions_create(messages, True)
+    async def request_stream(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> AsyncIterator[EitherStreamedResponse]:
+        response = await self._completions_create(messages, True, model_settings)
         async with response:
             yield await self._process_streamed_response(response)
 
     @overload
     async def _completions_create(
-        self, messages: list[Message], stream: Literal[True]
+        self, messages: list[ModelMessage], stream: Literal[True], model_settings: ModelSettings | None
    ) -> AsyncStream[ChatCompletionChunk]:
         pass
 
     @overload
-    async def _completions_create(self, messages: list[Message], stream: Literal[False]) -> chat.ChatCompletion:
+    async def _completions_create(
+        self, messages: list[ModelMessage], stream: Literal[False], model_settings: ModelSettings | None
+    ) -> chat.ChatCompletion:
         pass
 
     async def _completions_create(
-        self, messages: list[Message], stream: bool
+        self, messages: list[ModelMessage], stream: bool, model_settings: ModelSettings | None
     ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
         # standalone function to make it easier to override
         if not self.tools:
@@ -180,31 +193,36 @@ class GroqAgentModel(AgentModel):
         else:
             tool_choice = 'auto'
 
-        groq_messages = [self._map_message(m) for m in messages]
+        groq_messages = list(chain(*(self._map_message(m) for m in messages)))
+
+        model_settings = model_settings or {}
+
         return await self.client.chat.completions.create(
             model=str(self.model_name),
             messages=groq_messages,
-            temperature=0.0,
             n=1,
             parallel_tool_calls=True if self.tools else NOT_GIVEN,
             tools=self.tools or NOT_GIVEN,
             tool_choice=tool_choice or NOT_GIVEN,
             stream=stream,
+            max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
+            temperature=model_settings.get('temperature', NOT_GIVEN),
+            top_p=model_settings.get('top_p', NOT_GIVEN),
+            timeout=model_settings.get('timeout', NOT_GIVEN),
         )
 
     @staticmethod
-    def _process_response(response: chat.ChatCompletion) -> ModelAnyResponse:
+    def _process_response(response: chat.ChatCompletion) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
         timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
         choice = response.choices[0]
+        items: list[ModelResponsePart] = []
+        if choice.message.content is not None:
+            items.append(TextPart(choice.message.content))
         if choice.message.tool_calls is not None:
-            return ModelStructuredResponse(
-                [ToolCall.from_json(c.function.name, c.function.arguments, c.id) for c in choice.message.tool_calls],
-                timestamp=timestamp,
-            )
-        else:
-            assert choice.message.content is not None, choice
-            return ModelTextResponse(choice.message.content, timestamp=timestamp)
+            for c in choice.message.tool_calls:
+                items.append(ToolCallPart.from_json(c.function.name, c.function.arguments, c.id))
+        return ModelResponse(items, timestamp=timestamp)
 
     @staticmethod
     async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> EitherStreamedResponse:
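The hard-coded `temperature=0.0` is gone: sampling options now come from the new `ModelSettings` TypedDict (added in `pydantic_ai/settings.py` in this release), and any key left unset falls through to `NOT_GIVEN` so the provider default applies. A hedged sketch of how a caller might pass settings, assuming `Agent.run_sync` accepts the `model_settings` argument that these new signatures thread through:

    from pydantic_ai import Agent

    agent = Agent('groq:llama-3.3-70b-versatile')

    # Every key is optional; omitted keys are sent as NOT_GIVEN, so the
    # provider default applies instead of the old forced temperature=0.0.
    result = agent.run_sync(
        'Summarise the plot of Hamlet in one sentence.',
        model_settings={'temperature': 0.2, 'max_tokens': 128},
    )
    print(result.data)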
@@ -233,47 +251,55 @@ class GroqAgentModel(AgentModel):
             start_cost,
         )
 
-    @staticmethod
-    def _map_message(message: Message) -> chat.ChatCompletionMessageParam:
+    @classmethod
+    def _map_message(cls, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
         """Just maps a `pydantic_ai.Message` to a `groq.types.ChatCompletionMessageParam`."""
-        if message.role == 'system':
-            # SystemPrompt ->
-            return chat.ChatCompletionSystemMessageParam(role='system', content=message.content)
-        elif message.role == 'user':
-            # UserPrompt ->
-            return chat.ChatCompletionUserMessageParam(role='user', content=message.content)
-        elif message.role == 'tool-return':
-            # ToolReturn ->
-            return chat.ChatCompletionToolMessageParam(
-                role='tool',
-                tool_call_id=_guard_tool_id(message),
-                content=message.model_response_str(),
-            )
-        elif message.role == 'retry-prompt':
-            # RetryPrompt ->
-            if message.tool_name is None:
-                return chat.ChatCompletionUserMessageParam(role='user', content=message.model_response())
-            else:
-                return chat.ChatCompletionToolMessageParam(
-                    role='tool',
-                    tool_call_id=_guard_tool_id(message),
-                    content=message.model_response(),
-                )
-        elif message.role == 'model-text-response':
-            # ModelTextResponse ->
-            return chat.ChatCompletionAssistantMessageParam(role='assistant', content=message.content)
-        elif message.role == 'model-structured-response':
-            assert (
-                message.role == 'model-structured-response'
-            ), f'Expected role to be "llm-tool-calls", got {message.role}'
-            # ModelStructuredResponse ->
-            return chat.ChatCompletionAssistantMessageParam(
-                role='assistant',
-                tool_calls=[_map_tool_call(t) for t in message.calls],
-            )
+        if isinstance(message, ModelRequest):
+            yield from cls._map_user_message(message)
+        elif isinstance(message, ModelResponse):
+            texts: list[str] = []
+            tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
+            for item in message.parts:
+                if isinstance(item, TextPart):
+                    texts.append(item.content)
+                elif isinstance(item, ToolCallPart):
+                    tool_calls.append(_map_tool_call(item))
+                else:
+                    assert_never(item)
+            message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')
+            if texts:
+                # Note: model responses from this model should only have one text item, so the following
+                # shouldn't merge multiple texts into one unless you switch models between runs:
+                message_param['content'] = '\n\n'.join(texts)
+            if tool_calls:
+                message_param['tool_calls'] = tool_calls
+            yield message_param
         else:
             assert_never(message)
 
+    @classmethod
+    def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
+        for part in message.parts:
+            if isinstance(part, SystemPromptPart):
+                yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
+            elif isinstance(part, UserPromptPart):
+                yield chat.ChatCompletionUserMessageParam(role='user', content=part.content)
+            elif isinstance(part, ToolReturnPart):
+                yield chat.ChatCompletionToolMessageParam(
+                    role='tool',
+                    tool_call_id=_guard_tool_call_id(t=part, model_source='Groq'),
+                    content=part.model_response_str(),
+                )
+            elif isinstance(part, RetryPromptPart):
+                if part.tool_name is None:
+                    yield chat.ChatCompletionUserMessageParam(role='user', content=part.model_response())
+                else:
+                    yield chat.ChatCompletionToolMessageParam(
+                        role='tool',
+                        tool_call_id=_guard_tool_call_id(t=part, model_source='Groq'),
+                        content=part.model_response(),
+                    )
+
 
 @dataclass
 class GroqStreamTextResponse(StreamTextResponse):
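`_map_message` changes from a 1:1 `staticmethod` keyed on `message.role` strings to a generator that can yield several Groq chat messages per `ModelMessage` (one per part of a `ModelRequest`), which is why `_completions_create` above flattens the results with `itertools.chain`. A simplified, purely illustrative re-implementation of that one-to-many shape using plain dicts (names here are hypothetical, not the library's API):

    from itertools import chain
    from typing import Iterable

    # Hypothetical stand-in: each (kind, content) pair is one request part.
    def map_request(parts: list[tuple[str, str]]) -> Iterable[dict[str, str]]:
        role_map = {'system-prompt': 'system', 'user-prompt': 'user'}
        for kind, content in parts:
            yield {'role': role_map[kind], 'content': content}

    requests = [
        [('system-prompt', 'Be concise.'), ('user-prompt', 'Hello!')],
    ]
    # Mirrors: list(chain(*(self._map_message(m) for m in messages)))
    messages = list(chain(*(map_request(r) for r in requests)))
    assert messages == [
        {'role': 'system', 'content': 'Be concise.'},
        {'role': 'user', 'content': 'Hello!'},
    ]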
@@ -349,14 +375,14 @@ class GroqStreamStructuredResponse(StreamStructuredResponse):
         else:
             self._delta_tool_calls[new.index] = new
 
-    def get(self, *, final: bool = False) -> ModelStructuredResponse:
-        calls: list[ToolCall] = []
+    def get(self, *, final: bool = False) -> ModelResponse:
+        items: list[ModelResponsePart] = []
         for c in self._delta_tool_calls.values():
             if f := c.function:
                 if f.name is not None and f.arguments is not None:
-                    calls.append(ToolCall.from_json(f.name, f.arguments, c.id))
+                    items.append(ToolCallPart.from_json(f.name, f.arguments, c.id))
 
-        return ModelStructuredResponse(calls, timestamp=self._timestamp)
+        return ModelResponse(items, timestamp=self._timestamp)
 
     def cost(self) -> Cost:
         return self._cost
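On the streaming side, `get()` now assembles the accumulated tool-call deltas into the same `ModelResponse`/`ToolCallPart` shape as the non-streamed path, so both paths return one message type. From the caller's perspective streaming is unchanged; a sketch assuming this release's `Agent.run_stream` API:

    import asyncio

    from pydantic_ai import Agent

    agent = Agent('groq:llama-3.3-70b-versatile')

    async def main() -> None:
        # Under the hood, GroqStreamStructuredResponse.get() now builds a
        # ModelResponse from the buffered tool-call deltas on each read.
        async with agent.run_stream('Name three rivers.') as result:
            async for text in result.stream_text():
                print(text)

    asyncio.run(main())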
@@ -365,16 +391,10 @@ class GroqStreamStructuredResponse(StreamStructuredResponse):
         return self._timestamp
 
 
-def _guard_tool_id(t: ToolCall | ToolReturn | RetryPrompt) -> str:
-    """Type guard that checks a `tool_id` is not None both for static typing and runtime."""
-    assert t.tool_id is not None, f'Groq requires `tool_id` to be set: {t}'
-    return t.tool_id
-
-
-def _map_tool_call(t: ToolCall) -> chat.ChatCompletionMessageToolCallParam:
+def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
     assert isinstance(t.args, ArgsJson), f'Expected ArgsJson, got {t.args}'
     return chat.ChatCompletionMessageToolCallParam(
-        id=_guard_tool_id(t),
+        id=_guard_tool_call_id(t=t, model_source='Groq'),
         type='function',
         function={'name': t.tool_name, 'arguments': t.args.args_json},
     )
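Finally, the module-local `_guard_tool_id` helper is deleted in favour of the shared `guard_tool_call_id` imported from `pydantic_ai._utils` at the top of the file, so the same None-check is reused across provider backends. Based on the helper it replaces, it plausibly looks something like this (a sketch, not the actual `_utils` source):

    # Sketch of the shared guard, modelled on the deleted local helper;
    # the real implementation lives in pydantic_ai/_utils.py.
    def guard_tool_call_id(*, t, model_source: str) -> str:
        """Type guard that checks a `tool_call_id` is not None both for static typing and runtime."""
        assert t.tool_call_id is not None, f'{model_source} requires `tool_call_id` to be set: {t}'
        return t.tool_call_id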