pydantic-ai-slim 0.0.12__py3-none-any.whl → 0.0.14__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release: this version of pydantic-ai-slim has been flagged as potentially problematic.
- pydantic_ai/__init__.py +12 -2
- pydantic_ai/_pydantic.py +7 -25
- pydantic_ai/_result.py +33 -18
- pydantic_ai/_system_prompt.py +2 -2
- pydantic_ai/_utils.py +9 -2
- pydantic_ai/agent.py +366 -171
- pydantic_ai/exceptions.py +20 -2
- pydantic_ai/messages.py +111 -50
- pydantic_ai/models/__init__.py +39 -14
- pydantic_ai/models/anthropic.py +344 -0
- pydantic_ai/models/function.py +62 -40
- pydantic_ai/models/gemini.py +164 -124
- pydantic_ai/models/groq.py +112 -94
- pydantic_ai/models/mistral.py +668 -0
- pydantic_ai/models/ollama.py +1 -1
- pydantic_ai/models/openai.py +120 -96
- pydantic_ai/models/test.py +78 -61
- pydantic_ai/models/vertexai.py +7 -3
- pydantic_ai/result.py +96 -68
- pydantic_ai/settings.py +137 -0
- pydantic_ai/tools.py +46 -26
- {pydantic_ai_slim-0.0.12.dist-info → pydantic_ai_slim-0.0.14.dist-info}/METADATA +8 -3
- pydantic_ai_slim-0.0.14.dist-info/RECORD +26 -0
- {pydantic_ai_slim-0.0.12.dist-info → pydantic_ai_slim-0.0.14.dist-info}/WHEEL +1 -1
- pydantic_ai_slim-0.0.12.dist-info/RECORD +0 -23
pydantic_ai/models/groq.py
CHANGED
@@ -4,23 +4,28 @@ from collections.abc import AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
+from itertools import chain
 from typing import Literal, overload
 
 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never
 
 from .. import UnexpectedModelBehavior, _utils, result
+from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
-    …
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    ModelResponsePart,
+    RetryPromptPart,
+    SystemPromptPart,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
+    UserPromptPart,
 )
-from ..result import …
+from ..result import Usage
+from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
     AgentModel,
@@ -40,10 +45,11 @@ try:
 except ImportError as _import_error:
     raise ImportError(
         'Please install `groq` to use the Groq model, '
-        "you can use the `groq` optional group — `pip install 'pydantic-ai[groq]'`"
+        "you can use the `groq` optional group — `pip install 'pydantic-ai-slim[groq]'`"
     ) from _import_error
 
 GroqModelName = Literal[
+    'llama-3.3-70b-versatile',
     'llama-3.1-70b-versatile',
     'llama3-groq-70b-8192-tool-use-preview',
     'llama3-groq-8b-8192-tool-use-preview',
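
This hunk corrects the install hint to the slim package's extra and registers the new `llama-3.3-70b-versatile` model name. A minimal sketch of selecting the new model, assuming the `GroqModel` class defined elsewhere in this module and its usual constructor (reads `GROQ_API_KEY` from the environment when no key is passed):

    # pip install 'pydantic-ai-slim[groq]'
    from pydantic_ai.models.groq import GroqModel

    # Any GroqModelName literal is accepted, including the newly added one:
    model = GroqModel('llama-3.3-70b-versatile')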
@@ -149,28 +155,34 @@ class GroqAgentModel(AgentModel):
     allow_text_result: bool
     tools: list[chat.ChatCompletionToolParam]
 
-    async def request(
-        …
+    async def request(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> tuple[ModelResponse, result.Usage]:
+        response = await self._completions_create(messages, False, model_settings)
+        return self._process_response(response), _map_usage(response)
 
     @asynccontextmanager
-    async def request_stream(
-        …
+    async def request_stream(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> AsyncIterator[EitherStreamedResponse]:
+        response = await self._completions_create(messages, True, model_settings)
         async with response:
             yield await self._process_streamed_response(response)
 
     @overload
     async def _completions_create(
-        self, messages: list[…
+        self, messages: list[ModelMessage], stream: Literal[True], model_settings: ModelSettings | None
     ) -> AsyncStream[ChatCompletionChunk]:
         pass
 
     @overload
-    async def _completions_create(
+    async def _completions_create(
+        self, messages: list[ModelMessage], stream: Literal[False], model_settings: ModelSettings | None
+    ) -> chat.ChatCompletion:
         pass
 
     async def _completions_create(
-        self, messages: list[…
+        self, messages: list[ModelMessage], stream: bool, model_settings: ModelSettings | None
     ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
         # standalone function to make it easier to override
         if not self.tools:
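
Both entry points now take an optional `ModelSettings` mapping alongside the message history, and `request` returns usage statistics with the response. A hedged sketch of the new call surface; `agent_model` is assumed to be an already-constructed `GroqAgentModel`, which is not shown in this hunk:

    from pydantic_ai.messages import ModelRequest, UserPromptPart

    async def demo(agent_model) -> None:
        messages = [ModelRequest(parts=[UserPromptPart(content='What is 2 + 2?')])]
        settings = {'temperature': 0.0, 'max_tokens': 64}  # ModelSettings is a plain mapping

        # non-streamed request: now returns (ModelResponse, Usage)
        response, usage = await agent_model.request(messages, settings)

        # streamed request: an async context manager yielding an EitherStreamedResponse
        async with agent_model.request_stream(messages, settings) as streamed:
            ...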
@@ -180,37 +192,42 @@ class GroqAgentModel(AgentModel):
         else:
             tool_choice = 'auto'
 
-        groq_messages = …
+        groq_messages = list(chain(*(self._map_message(m) for m in messages)))
+
+        model_settings = model_settings or {}
+
         return await self.client.chat.completions.create(
             model=str(self.model_name),
             messages=groq_messages,
-            temperature=0.0,
             n=1,
             parallel_tool_calls=True if self.tools else NOT_GIVEN,
             tools=self.tools or NOT_GIVEN,
             tool_choice=tool_choice or NOT_GIVEN,
             stream=stream,
+            max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
+            temperature=model_settings.get('temperature', NOT_GIVEN),
+            top_p=model_settings.get('top_p', NOT_GIVEN),
+            timeout=model_settings.get('timeout', NOT_GIVEN),
         )
 
     @staticmethod
-    def _process_response(response: chat.ChatCompletion) -> …
+    def _process_response(response: chat.ChatCompletion) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
         timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
         choice = response.choices[0]
+        items: list[ModelResponsePart] = []
+        if choice.message.content is not None:
+            items.append(TextPart(choice.message.content))
         if choice.message.tool_calls is not None:
-            …
-            )
-        else:
-            assert choice.message.content is not None, choice
-            return ModelTextResponse(choice.message.content, timestamp=timestamp)
+            for c in choice.message.tool_calls:
+                items.append(ToolCallPart.from_raw_args(c.function.name, c.function.arguments, c.id))
+        return ModelResponse(items, timestamp=timestamp)
 
     @staticmethod
     async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> EitherStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         timestamp: datetime | None = None
-        …
+        start_usage = Usage()
         # the first chunk may contain enough information so we iterate until we get either `tool_calls` or `content`
         while True:
             try:
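
The hard-coded `temperature=0.0` is gone: sampling options now come from the per-request `model_settings` mapping, with `NOT_GIVEN` as the fallback so unset keys are simply left to the Groq SDK defaults. An illustrative sketch of that lookup pattern; the four key names match the `.get()` calls in the hunk, everything else is made up for the example:

    from groq import NOT_GIVEN

    model_settings = {'temperature': 0.2, 'max_tokens': 512}  # 'top_p' and 'timeout' left unset

    kwargs = {
        'max_tokens': model_settings.get('max_tokens', NOT_GIVEN),    # 512
        'temperature': model_settings.get('temperature', NOT_GIVEN),  # 0.2
        'top_p': model_settings.get('top_p', NOT_GIVEN),              # NOT_GIVEN -> SDK default
        'timeout': model_settings.get('timeout', NOT_GIVEN),          # NOT_GIVEN -> SDK default
    }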
@@ -218,62 +235,70 @@ class GroqAgentModel(AgentModel):
             except StopAsyncIteration as e:
                 raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e
             timestamp = timestamp or datetime.fromtimestamp(chunk.created, tz=timezone.utc)
-            …
+            start_usage += _map_usage(chunk)
 
             if chunk.choices:
                 delta = chunk.choices[0].delta
 
                 if delta.content is not None:
-                    return GroqStreamTextResponse(delta.content, response, timestamp, …
+                    return GroqStreamTextResponse(delta.content, response, timestamp, start_usage)
                 elif delta.tool_calls is not None:
                     return GroqStreamStructuredResponse(
                         response,
                         {c.index: c for c in delta.tool_calls},
                         timestamp,
-                        …
+                        start_usage,
                     )
 
-    @…
-    def _map_message(message: …
+    @classmethod
+    def _map_message(cls, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
         """Just maps a `pydantic_ai.Message` to a `groq.types.ChatCompletionMessageParam`."""
-        if message…
-            …
-            )
-            …
-                tool_call_id=_guard_tool_id(message),
-                content=message.model_response(),
-            )
-        elif message.role == 'model-text-response':
-            # ModelTextResponse -> …
-            return chat.ChatCompletionAssistantMessageParam(role='assistant', content=message.content)
-        elif message.role == 'model-structured-response':
-            assert (
-                message.role == 'model-structured-response'
-            ), f'Expected role to be "llm-tool-calls", got {message.role}'
-            # ModelStructuredResponse -> …
-            return chat.ChatCompletionAssistantMessageParam(
-                role='assistant',
-                tool_calls=[_map_tool_call(t) for t in message.calls],
-            )
+        if isinstance(message, ModelRequest):
+            yield from cls._map_user_message(message)
+        elif isinstance(message, ModelResponse):
+            texts: list[str] = []
+            tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
+            for item in message.parts:
+                if isinstance(item, TextPart):
+                    texts.append(item.content)
+                elif isinstance(item, ToolCallPart):
+                    tool_calls.append(_map_tool_call(item))
+                else:
+                    assert_never(item)
+            message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')
+            if texts:
+                # Note: model responses from this model should only have one text item, so the following
+                # shouldn't merge multiple texts into one unless you switch models between runs:
+                message_param['content'] = '\n\n'.join(texts)
+            if tool_calls:
+                message_param['tool_calls'] = tool_calls
+            yield message_param
         else:
             assert_never(message)
 
+    @classmethod
+    def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
+        for part in message.parts:
+            if isinstance(part, SystemPromptPart):
+                yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
+            elif isinstance(part, UserPromptPart):
+                yield chat.ChatCompletionUserMessageParam(role='user', content=part.content)
+            elif isinstance(part, ToolReturnPart):
+                yield chat.ChatCompletionToolMessageParam(
+                    role='tool',
+                    tool_call_id=_guard_tool_call_id(t=part, model_source='Groq'),
+                    content=part.model_response_str(),
+                )
+            elif isinstance(part, RetryPromptPart):
+                if part.tool_name is None:
+                    yield chat.ChatCompletionUserMessageParam(role='user', content=part.model_response())
+                else:
+                    yield chat.ChatCompletionToolMessageParam(
+                        role='tool',
+                        tool_call_id=_guard_tool_call_id(t=part, model_source='Groq'),
+                        content=part.model_response(),
+                    )
+
 
 @dataclass
 class GroqStreamTextResponse(StreamTextResponse):
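
This is the core of the 0.0.14 message-model change: the old role-tagged message classes (`ModelTextResponse`, `ModelStructuredResponse`, and friends) are replaced by two containers, `ModelRequest` and `ModelResponse`, each holding typed parts, and `_map_message` now dispatches on part type instead of on a `role` string. A hedged sketch of the part-based history this mapper consumes; constructor keywords are assumed from the classes imported at the top of the file:

    from pydantic_ai.messages import (
        ModelRequest,
        ModelResponse,
        SystemPromptPart,
        TextPart,
        ToolCallPart,
        ToolReturnPart,
        UserPromptPart,
    )

    history = [
        # -> one 'system' and one 'user' chat message (via _map_user_message)
        ModelRequest(parts=[
            SystemPromptPart(content='Be concise.'),
            UserPromptPart(content='What is the weather in Paris?'),
        ]),
        # -> a single 'assistant' message carrying both text and a tool call
        ModelResponse(parts=[
            TextPart(content='Let me check.'),
            ToolCallPart.from_raw_args('get_weather', '{"city": "Paris"}', 'call_1'),
        ]),
        # -> one 'tool' message with the matching tool_call_id
        ModelRequest(parts=[
            ToolReturnPart(tool_name='get_weather', content='18°C, sunny', tool_call_id='call_1'),
        ]),
    ]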
@@ -282,7 +307,7 @@ class GroqStreamTextResponse(StreamTextResponse):
     _first: str | None
     _response: AsyncStream[ChatCompletionChunk]
     _timestamp: datetime
-    …
+    _usage: result.Usage
     _buffer: list[str] = field(default_factory=list, init=False)
 
     async def __anext__(self) -> None:
@@ -292,7 +317,7 @@ class GroqStreamTextResponse(StreamTextResponse):
             return None
 
         chunk = await self._response.__anext__()
-        self.…
+        self._usage = _map_usage(chunk)
 
         try:
             choice = chunk.choices[0]
@@ -309,8 +334,8 @@ class GroqStreamTextResponse(StreamTextResponse):
         yield from self._buffer
         self._buffer.clear()
 
-    def …
-        return self.…
+    def usage(self) -> Usage:
+        return self._usage
 
     def timestamp(self) -> datetime:
         return self._timestamp
@@ -323,11 +348,11 @@ class GroqStreamStructuredResponse(StreamStructuredResponse):
     _response: AsyncStream[ChatCompletionChunk]
     _delta_tool_calls: dict[int, ChoiceDeltaToolCall]
     _timestamp: datetime
-    …
+    _usage: result.Usage
 
     async def __anext__(self) -> None:
         chunk = await self._response.__anext__()
-        self.…
+        self._usage = _map_usage(chunk)
 
         try:
             choice = chunk.choices[0]
@@ -349,38 +374,31 @@ class GroqStreamStructuredResponse(StreamStructuredResponse):
         else:
             self._delta_tool_calls[new.index] = new
 
-    def get(self, *, final: bool = False) -> …
-        …
+    def get(self, *, final: bool = False) -> ModelResponse:
+        items: list[ModelResponsePart] = []
         for c in self._delta_tool_calls.values():
             if f := c.function:
                 if f.name is not None and f.arguments is not None:
-                    …
+                    items.append(ToolCallPart.from_raw_args(f.name, f.arguments, c.id))
 
-        return …
+        return ModelResponse(items, timestamp=self._timestamp)
 
-    def …
-        return self.…
+    def usage(self) -> Usage:
+        return self._usage
 
     def timestamp(self) -> datetime:
         return self._timestamp
 
 
-def …
-    """Type guard that checks a `tool_id` is not None both for static typing and runtime."""
-    assert t.tool_id is not None, f'Groq requires `tool_id` to be set: {t}'
-    return t.tool_id
-
-
-def _map_tool_call(t: ToolCall) -> chat.ChatCompletionMessageToolCallParam:
-    assert isinstance(t.args, ArgsJson), f'Expected ArgsJson, got {t.args}'
+def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
     return chat.ChatCompletionMessageToolCallParam(
-        id=…
+        id=_guard_tool_call_id(t=t, model_source='Groq'),
         type='function',
-        function={'name': t.tool_name, 'arguments': t.…
+        function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
    )
 
 
-def …
+def _map_usage(completion: ChatCompletionChunk | ChatCompletion) -> result.Usage:
     usage = None
     if isinstance(completion, ChatCompletion):
         usage = completion.usage
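
The module-local `_guard_tool_id` helper is dropped in favour of the shared `guard_tool_call_id` from `.._utils`, and tool-call arguments are serialised via `ToolCallPart.args_as_json_str()` rather than by asserting on the old `ArgsJson` wrapper. A small illustrative round trip for `_map_tool_call`; the expected dict mirrors the `ChatCompletionMessageToolCallParam` fields set above and assumes `_map_tool_call` is in scope:

    from pydantic_ai.messages import ToolCallPart

    part = ToolCallPart.from_raw_args('get_weather', '{"city": "Paris"}', 'call_1')
    assert _map_tool_call(part) == {
        'id': 'call_1',  # guard_tool_call_id raises if the id is missing
        'type': 'function',
        'function': {'name': 'get_weather', 'arguments': '{"city": "Paris"}'},
    }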
@@ -388,9 +406,9 @@ def _map_cost(completion: ChatCompletionChunk | ChatCompletion) -> result.Cost:
         usage = completion.x_groq.usage
 
     if usage is None:
-        return result.…
+        return result.Usage()
 
-    return result.…
+    return result.Usage(
         request_tokens=usage.prompt_tokens,
         response_tokens=usage.completion_tokens,
         total_tokens=usage.total_tokens,