pydantic-ai-slim 0.0.12__py3-none-any.whl → 0.0.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.

pydantic_ai/models/groq.py

@@ -4,23 +4,28 @@ from collections.abc import AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
+from itertools import chain
 from typing import Literal, overload
 
 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never
 
 from .. import UnexpectedModelBehavior, _utils, result
+from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
-    ArgsJson,
-    Message,
-    ModelAnyResponse,
-    ModelStructuredResponse,
-    ModelTextResponse,
-    RetryPrompt,
-    ToolCall,
-    ToolReturn,
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    ModelResponsePart,
+    RetryPromptPart,
+    SystemPromptPart,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
+    UserPromptPart,
 )
-from ..result import Cost
+from ..result import Usage
+from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
     AgentModel,
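
The import changes above are the heart of this release: the old single `Message` union (with role strings such as 'user' and 'tool-return') is replaced by two containers, `ModelRequest` and `ModelResponse`, each holding a list of typed parts. A minimal sketch of the new shapes, assuming the 0.0.14 public API (the part names come straight from the import block above):

    from pydantic_ai.messages import ModelRequest, ModelResponse, TextPart, UserPromptPart

    # one request now carries a list of typed parts instead of one role-tagged message
    request = ModelRequest(parts=[UserPromptPart(content='What changed in 0.0.14?')])
    # a response is likewise a list of parts, not a text/structured either-or
    response = ModelResponse([TextPart('The message model was reworked.')])
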
@@ -40,10 +45,11 @@ try:
 except ImportError as _import_error:
     raise ImportError(
         'Please install `groq` to use the Groq model, '
-        "you can use the `groq` optional group — `pip install 'pydantic-ai[groq]'`"
+        "you can use the `groq` optional group — `pip install 'pydantic-ai-slim[groq]'`"
     ) from _import_error
 
 GroqModelName = Literal[
+    'llama-3.3-70b-versatile',
     'llama-3.1-70b-versatile',
     'llama3-groq-70b-8192-tool-use-preview',
     'llama3-groq-8b-8192-tool-use-preview',
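
Two small fixes ride along here: the ImportError hint now names this package's own optional group (`pydantic-ai-slim[groq]` rather than `pydantic-ai[groq]`), and Groq's newer `llama-3.3-70b-versatile` joins the `GroqModelName` literal. A sketch of selecting the new model, assuming the `GroqModel` constructor defined elsewhere in this module takes the model name first:

    from pydantic_ai.models.groq import GroqModel

    # may also require a `GROQ_API_KEY` environment variable or `api_key` argument
    model = GroqModel('llama-3.3-70b-versatile')  # now accepted by the widened Literal
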
@@ -149,28 +155,34 @@ class GroqAgentModel(AgentModel):
     allow_text_result: bool
     tools: list[chat.ChatCompletionToolParam]
 
-    async def request(self, messages: list[Message]) -> tuple[ModelAnyResponse, result.Cost]:
-        response = await self._completions_create(messages, False)
-        return self._process_response(response), _map_cost(response)
+    async def request(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> tuple[ModelResponse, result.Usage]:
+        response = await self._completions_create(messages, False, model_settings)
+        return self._process_response(response), _map_usage(response)
 
     @asynccontextmanager
-    async def request_stream(self, messages: list[Message]) -> AsyncIterator[EitherStreamedResponse]:
-        response = await self._completions_create(messages, True)
+    async def request_stream(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> AsyncIterator[EitherStreamedResponse]:
+        response = await self._completions_create(messages, True, model_settings)
         async with response:
             yield await self._process_streamed_response(response)
 
     @overload
     async def _completions_create(
-        self, messages: list[Message], stream: Literal[True]
+        self, messages: list[ModelMessage], stream: Literal[True], model_settings: ModelSettings | None
     ) -> AsyncStream[ChatCompletionChunk]:
         pass
 
     @overload
-    async def _completions_create(self, messages: list[Message], stream: Literal[False]) -> chat.ChatCompletion:
+    async def _completions_create(
+        self, messages: list[ModelMessage], stream: Literal[False], model_settings: ModelSettings | None
+    ) -> chat.ChatCompletion:
         pass
 
     async def _completions_create(
-        self, messages: list[Message], stream: bool
+        self, messages: list[ModelMessage], stream: bool, model_settings: ModelSettings | None
     ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
         # standalone function to make it easier to override
         if not self.tools:
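
Both `request` and `request_stream` now take a `ModelSettings | None` argument and thread it through every `_completions_create` overload. A sketch of how a caller might supply settings, assuming `Agent.run_sync` gained a matching `model_settings` parameter in this release (the dict keys mirror the `.get()` calls in the next hunk):

    from pydantic_ai import Agent

    agent = Agent('groq:llama-3.3-70b-versatile')
    result = agent.run_sync(
        'Hello!',
        model_settings={'temperature': 0.0, 'max_tokens': 500},  # hypothetical values
    )
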
@@ -180,37 +192,42 @@ class GroqAgentModel(AgentModel):
         else:
             tool_choice = 'auto'
 
-        groq_messages = [self._map_message(m) for m in messages]
+        groq_messages = list(chain(*(self._map_message(m) for m in messages)))
+
+        model_settings = model_settings or {}
+
         return await self.client.chat.completions.create(
             model=str(self.model_name),
             messages=groq_messages,
-            temperature=0.0,
             n=1,
             parallel_tool_calls=True if self.tools else NOT_GIVEN,
             tools=self.tools or NOT_GIVEN,
             tool_choice=tool_choice or NOT_GIVEN,
             stream=stream,
+            max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
+            temperature=model_settings.get('temperature', NOT_GIVEN),
+            top_p=model_settings.get('top_p', NOT_GIVEN),
+            timeout=model_settings.get('timeout', NOT_GIVEN),
         )
 
     @staticmethod
-    def _process_response(response: chat.ChatCompletion) -> ModelAnyResponse:
+    def _process_response(response: chat.ChatCompletion) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
         timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
         choice = response.choices[0]
+        items: list[ModelResponsePart] = []
+        if choice.message.content is not None:
+            items.append(TextPart(choice.message.content))
         if choice.message.tool_calls is not None:
-            return ModelStructuredResponse(
-                [ToolCall.from_json(c.function.name, c.function.arguments, c.id) for c in choice.message.tool_calls],
-                timestamp=timestamp,
-            )
-        else:
-            assert choice.message.content is not None, choice
-            return ModelTextResponse(choice.message.content, timestamp=timestamp)
+            for c in choice.message.tool_calls:
+                items.append(ToolCallPart.from_raw_args(c.function.name, c.function.arguments, c.id))
+        return ModelResponse(items, timestamp=timestamp)
 
     @staticmethod
     async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> EitherStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         timestamp: datetime | None = None
-        start_cost = Cost()
+        start_usage = Usage()
         # the first chunk may contain enough information so we iterate until we get either `tool_calls` or `content`
         while True:
             try:
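
Two behaviour changes in the hunk above: the hard-coded `temperature=0.0` is gone, so any setting left out of `model_settings` falls back to `NOT_GIVEN` and Groq's server-side defaults apply; and `_process_response` now collects text and tool calls into one `ModelResponse` instead of choosing between `ModelTextResponse` and `ModelStructuredResponse`. A sketch of the combined shape it can now return (hypothetical values):

    from datetime import datetime, timezone

    from pydantic_ai.messages import ModelResponse, TextPart, ToolCallPart

    # a single assistant turn carrying both text and a tool call
    response = ModelResponse(
        [
            TextPart('Let me check the weather.'),
            ToolCallPart.from_raw_args('get_weather', '{"city": "Paris"}', 'call_1'),
        ],
        timestamp=datetime.now(tz=timezone.utc),
    )
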
@@ -218,62 +235,70 @@ class GroqAgentModel(AgentModel):
             except StopAsyncIteration as e:
                 raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e
             timestamp = timestamp or datetime.fromtimestamp(chunk.created, tz=timezone.utc)
-            start_cost += _map_cost(chunk)
+            start_usage += _map_usage(chunk)
 
             if chunk.choices:
                 delta = chunk.choices[0].delta
 
                 if delta.content is not None:
-                    return GroqStreamTextResponse(delta.content, response, timestamp, start_cost)
+                    return GroqStreamTextResponse(delta.content, response, timestamp, start_usage)
                 elif delta.tool_calls is not None:
                     return GroqStreamStructuredResponse(
                         response,
                         {c.index: c for c in delta.tool_calls},
                         timestamp,
-                        start_cost,
+                        start_usage,
                     )
 
-    @staticmethod
-    def _map_message(message: Message) -> chat.ChatCompletionMessageParam:
+    @classmethod
+    def _map_message(cls, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
         """Just maps a `pydantic_ai.Message` to a `groq.types.ChatCompletionMessageParam`."""
-        if message.role == 'system':
-            # SystemPrompt ->
-            return chat.ChatCompletionSystemMessageParam(role='system', content=message.content)
-        elif message.role == 'user':
-            # UserPrompt ->
-            return chat.ChatCompletionUserMessageParam(role='user', content=message.content)
-        elif message.role == 'tool-return':
-            # ToolReturn ->
-            return chat.ChatCompletionToolMessageParam(
-                role='tool',
-                tool_call_id=_guard_tool_id(message),
-                content=message.model_response_str(),
-            )
-        elif message.role == 'retry-prompt':
-            # RetryPrompt ->
-            if message.tool_name is None:
-                return chat.ChatCompletionUserMessageParam(role='user', content=message.model_response())
-            else:
-                return chat.ChatCompletionToolMessageParam(
-                    role='tool',
-                    tool_call_id=_guard_tool_id(message),
-                    content=message.model_response(),
-                )
-        elif message.role == 'model-text-response':
-            # ModelTextResponse ->
-            return chat.ChatCompletionAssistantMessageParam(role='assistant', content=message.content)
-        elif message.role == 'model-structured-response':
-            assert (
-                message.role == 'model-structured-response'
-            ), f'Expected role to be "llm-tool-calls", got {message.role}'
-            # ModelStructuredResponse ->
-            return chat.ChatCompletionAssistantMessageParam(
-                role='assistant',
-                tool_calls=[_map_tool_call(t) for t in message.calls],
-            )
+        if isinstance(message, ModelRequest):
+            yield from cls._map_user_message(message)
+        elif isinstance(message, ModelResponse):
+            texts: list[str] = []
+            tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
+            for item in message.parts:
+                if isinstance(item, TextPart):
+                    texts.append(item.content)
+                elif isinstance(item, ToolCallPart):
+                    tool_calls.append(_map_tool_call(item))
+                else:
+                    assert_never(item)
+            message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')
+            if texts:
+                # Note: model responses from this model should only have one text item, so the following
+                # shouldn't merge multiple texts into one unless you switch models between runs:
+                message_param['content'] = '\n\n'.join(texts)
+            if tool_calls:
+                message_param['tool_calls'] = tool_calls
+            yield message_param
         else:
             assert_never(message)
 
+    @classmethod
+    def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
+        for part in message.parts:
+            if isinstance(part, SystemPromptPart):
+                yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
+            elif isinstance(part, UserPromptPart):
+                yield chat.ChatCompletionUserMessageParam(role='user', content=part.content)
+            elif isinstance(part, ToolReturnPart):
+                yield chat.ChatCompletionToolMessageParam(
+                    role='tool',
+                    tool_call_id=_guard_tool_call_id(t=part, model_source='Groq'),
+                    content=part.model_response_str(),
+                )
+            elif isinstance(part, RetryPromptPart):
+                if part.tool_name is None:
+                    yield chat.ChatCompletionUserMessageParam(role='user', content=part.model_response())
+                else:
+                    yield chat.ChatCompletionToolMessageParam(
+                        role='tool',
+                        tool_call_id=_guard_tool_call_id(t=part, model_source='Groq'),
+                        content=part.model_response(),
+                    )
+
 
 @dataclass
 class GroqStreamTextResponse(StreamTextResponse):
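
Because `_map_message` is now a generator, one `ModelRequest` whose parts mix system prompts, user prompts, and tool returns fans out into one Groq message per part, flattened by the `chain(...)` call shown earlier. A sketch of that fan-out, under the same API assumptions as above (the `ChatCompletion*MessageParam` TypedDicts render as plain dicts):

    from pydantic_ai.messages import ModelRequest, SystemPromptPart, UserPromptPart
    from pydantic_ai.models.groq import GroqAgentModel

    request = ModelRequest(
        parts=[SystemPromptPart(content='Be concise.'), UserPromptPart(content='Hi!')]
    )
    print(list(GroqAgentModel._map_message(request)))
    #> [{'role': 'system', 'content': 'Be concise.'}, {'role': 'user', 'content': 'Hi!'}]
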
@@ -282,7 +307,7 @@ class GroqStreamTextResponse(StreamTextResponse):
     _first: str | None
     _response: AsyncStream[ChatCompletionChunk]
     _timestamp: datetime
-    _cost: result.Cost
+    _usage: result.Usage
     _buffer: list[str] = field(default_factory=list, init=False)
 
     async def __anext__(self) -> None:
@@ -292,7 +317,7 @@ class GroqStreamTextResponse(StreamTextResponse):
             return None
 
         chunk = await self._response.__anext__()
-        self._cost = _map_cost(chunk)
+        self._usage = _map_usage(chunk)
 
         try:
             choice = chunk.choices[0]
@@ -309,8 +334,8 @@ class GroqStreamTextResponse(StreamTextResponse):
         yield from self._buffer
         self._buffer.clear()
 
-    def cost(self) -> Cost:
-        return self._cost
+    def usage(self) -> Usage:
+        return self._usage
 
     def timestamp(self) -> datetime:
         return self._timestamp
@@ -323,11 +348,11 @@ class GroqStreamStructuredResponse(StreamStructuredResponse):
     _response: AsyncStream[ChatCompletionChunk]
     _delta_tool_calls: dict[int, ChoiceDeltaToolCall]
     _timestamp: datetime
-    _cost: result.Cost
+    _usage: result.Usage
 
     async def __anext__(self) -> None:
         chunk = await self._response.__anext__()
-        self._cost = _map_cost(chunk)
+        self._usage = _map_usage(chunk)
 
         try:
             choice = chunk.choices[0]
@@ -349,38 +374,31 @@ class GroqStreamStructuredResponse(StreamStructuredResponse):
         else:
             self._delta_tool_calls[new.index] = new
 
-    def get(self, *, final: bool = False) -> ModelStructuredResponse:
-        calls: list[ToolCall] = []
+    def get(self, *, final: bool = False) -> ModelResponse:
+        items: list[ModelResponsePart] = []
         for c in self._delta_tool_calls.values():
             if f := c.function:
                 if f.name is not None and f.arguments is not None:
-                    calls.append(ToolCall.from_json(f.name, f.arguments, c.id))
+                    items.append(ToolCallPart.from_raw_args(f.name, f.arguments, c.id))
 
-        return ModelStructuredResponse(calls, timestamp=self._timestamp)
+        return ModelResponse(items, timestamp=self._timestamp)
 
-    def cost(self) -> Cost:
-        return self._cost
+    def usage(self) -> Usage:
+        return self._usage
 
     def timestamp(self) -> datetime:
         return self._timestamp
 
 
-def _guard_tool_id(t: ToolCall | ToolReturn | RetryPrompt) -> str:
-    """Type guard that checks a `tool_id` is not None both for static typing and runtime."""
-    assert t.tool_id is not None, f'Groq requires `tool_id` to be set: {t}'
-    return t.tool_id
-
-
-def _map_tool_call(t: ToolCall) -> chat.ChatCompletionMessageToolCallParam:
-    assert isinstance(t.args, ArgsJson), f'Expected ArgsJson, got {t.args}'
+def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
     return chat.ChatCompletionMessageToolCallParam(
-        id=_guard_tool_id(t),
+        id=_guard_tool_call_id(t=t, model_source='Groq'),
         type='function',
-        function={'name': t.tool_name, 'arguments': t.args.args_json},
+        function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
     )
 
 
-def _map_cost(completion: ChatCompletionChunk | ChatCompletion) -> result.Cost:
+def _map_usage(completion: ChatCompletionChunk | ChatCompletion) -> result.Usage:
     usage = None
     if isinstance(completion, ChatCompletion):
         usage = completion.usage
@@ -388,9 +406,9 @@ def _map_cost(completion: ChatCompletionChunk | ChatCompletion) -> result.Cost:
         usage = completion.x_groq.usage
 
     if usage is None:
-        return result.Cost()
+        return result.Usage()
 
-    return result.Cost(
+    return result.Usage(
         request_tokens=usage.prompt_tokens,
         response_tokens=usage.completion_tokens,
         total_tokens=usage.total_tokens,
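
The `Cost` → `Usage` rename runs through the rest of the module: both streamed-response classes expose `usage()` in place of `cost()`, and `_map_usage` builds the token counts from either a `ChatCompletion` or a chunk's `x_groq` payload. A sketch of the value it produces, with field names as in the hunk above (token numbers hypothetical):

    from pydantic_ai import result

    # Usage supports in-place addition, which the streaming loop relies on
    # (`start_usage += _map_usage(chunk)` in the hunks above)
    total = result.Usage()
    total += result.Usage(request_tokens=10, response_tokens=20, total_tokens=30)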