pydantic-ai-slim 0.0.12 → 0.0.13 (py3-none-any.whl)

This diff compares publicly available package versions as released to their public registries, and is provided for informational purposes only.

pydantic_ai/models/groq.py

@@ -4,23 +4,29 @@ from collections.abc import AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
+from itertools import chain
 from typing import Literal, overload

 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never

 from .. import UnexpectedModelBehavior, _utils, result
+from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
     ArgsJson,
-    Message,
-    ModelAnyResponse,
-    ModelStructuredResponse,
-    ModelTextResponse,
-    RetryPrompt,
-    ToolCall,
-    ToolReturn,
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    ModelResponsePart,
+    RetryPromptPart,
+    SystemPromptPart,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
+    UserPromptPart,
 )
 from ..result import Cost
+from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
     AgentModel,
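
The import changes track this release's message-model rework: the old `Message` union and the `ModelTextResponse`/`ModelStructuredResponse` classes give way to `ModelRequest`/`ModelResponse` containers made of typed parts. A minimal sketch of the new shape (constructor keywords assumed from how the rest of this diff builds these dataclasses):

    from pydantic_ai.messages import (
        ModelRequest,
        ModelResponse,
        SystemPromptPart,
        TextPart,
        UserPromptPart,
    )

    # A single request can now carry several parts:
    request = ModelRequest(
        parts=[
            SystemPromptPart(content='You are terse.'),
            UserPromptPart(content='What is the capital of France?'),
        ]
    )

    # A response is likewise a list of parts (text and/or tool calls):
    response = ModelResponse(parts=[TextPart(content='Paris.')])
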
@@ -40,10 +46,11 @@ try:
 except ImportError as _import_error:
     raise ImportError(
         'Please install `groq` to use the Groq model, '
-        "you can use the `groq` optional group — `pip install 'pydantic-ai[groq]'`"
+        "you can use the `groq` optional group — `pip install 'pydantic-ai-slim[groq]'`"
     ) from _import_error

 GroqModelName = Literal[
+    'llama-3.3-70b-versatile',
     'llama-3.1-70b-versatile',
     'llama3-groq-70b-8192-tool-use-preview',
     'llama3-groq-8b-8192-tool-use-preview',
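
Besides fixing the install hint to point at this package's own extra, this hunk adds the then-new `llama-3.3-70b-versatile` to the accepted model names. A hedged usage sketch (assumes `GROQ_API_KEY` is set in the environment; the `groq:` model-string prefix is how pydantic-ai routes a model name to this module):

    from pydantic_ai import Agent

    agent = Agent('groq:llama-3.3-70b-versatile')
    result = agent.run_sync('Say hello in one word.')
    print(result.data)
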
@@ -149,28 +156,34 @@ class GroqAgentModel(AgentModel):
     allow_text_result: bool
     tools: list[chat.ChatCompletionToolParam]

-    async def request(self, messages: list[Message]) -> tuple[ModelAnyResponse, result.Cost]:
-        response = await self._completions_create(messages, False)
+    async def request(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> tuple[ModelResponse, result.Cost]:
+        response = await self._completions_create(messages, False, model_settings)
         return self._process_response(response), _map_cost(response)

     @asynccontextmanager
-    async def request_stream(self, messages: list[Message]) -> AsyncIterator[EitherStreamedResponse]:
-        response = await self._completions_create(messages, True)
+    async def request_stream(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> AsyncIterator[EitherStreamedResponse]:
+        response = await self._completions_create(messages, True, model_settings)
         async with response:
             yield await self._process_streamed_response(response)

     @overload
     async def _completions_create(
-        self, messages: list[Message], stream: Literal[True]
+        self, messages: list[ModelMessage], stream: Literal[True], model_settings: ModelSettings | None
     ) -> AsyncStream[ChatCompletionChunk]:
         pass

     @overload
-    async def _completions_create(self, messages: list[Message], stream: Literal[False]) -> chat.ChatCompletion:
+    async def _completions_create(
+        self, messages: list[ModelMessage], stream: Literal[False], model_settings: ModelSettings | None
+    ) -> chat.ChatCompletion:
         pass

     async def _completions_create(
-        self, messages: list[Message], stream: bool
+        self, messages: list[ModelMessage], stream: bool, model_settings: ModelSettings | None
     ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
         # standalone function to make it easier to override
         if not self.tools:
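
`request`, `request_stream`, and all three `_completions_create` overloads now accept an optional `ModelSettings`, threading per-run sampling options down to the Groq client. From the caller's side this surfaces as the `model_settings` argument on agent runs; a hypothetical sketch, reusing the `agent` from the earlier example:

    # `model_settings` is forwarded down to GroqAgentModel.request():
    result = agent.run_sync(
        'Summarise this in one line.',
        model_settings={'temperature': 0.0, 'max_tokens': 200},
    )
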
@@ -180,31 +193,36 @@ class GroqAgentModel(AgentModel):
         else:
             tool_choice = 'auto'

-        groq_messages = [self._map_message(m) for m in messages]
+        groq_messages = list(chain(*(self._map_message(m) for m in messages)))
+
+        model_settings = model_settings or {}
+
         return await self.client.chat.completions.create(
             model=str(self.model_name),
             messages=groq_messages,
-            temperature=0.0,
             n=1,
             parallel_tool_calls=True if self.tools else NOT_GIVEN,
             tools=self.tools or NOT_GIVEN,
             tool_choice=tool_choice or NOT_GIVEN,
             stream=stream,
+            max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
+            temperature=model_settings.get('temperature', NOT_GIVEN),
+            top_p=model_settings.get('top_p', NOT_GIVEN),
+            timeout=model_settings.get('timeout', NOT_GIVEN),
         )

     @staticmethod
-    def _process_response(response: chat.ChatCompletion) -> ModelAnyResponse:
+    def _process_response(response: chat.ChatCompletion) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
         timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
         choice = response.choices[0]
+        items: list[ModelResponsePart] = []
+        if choice.message.content is not None:
+            items.append(TextPart(choice.message.content))
         if choice.message.tool_calls is not None:
-            return ModelStructuredResponse(
-                [ToolCall.from_json(c.function.name, c.function.arguments, c.id) for c in choice.message.tool_calls],
-                timestamp=timestamp,
-            )
-        else:
-            assert choice.message.content is not None, choice
-            return ModelTextResponse(choice.message.content, timestamp=timestamp)
+            for c in choice.message.tool_calls:
+                items.append(ToolCallPart.from_json(c.function.name, c.function.arguments, c.id))
+        return ModelResponse(items, timestamp=timestamp)

     @staticmethod
     async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> EitherStreamedResponse:
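
Note the hard-coded `temperature=0.0` is gone: every knob now defaults to `NOT_GIVEN`, so Groq's server-side defaults apply unless the caller sets a value. The keys read here suggest `ModelSettings` is a `total=False` TypedDict along these lines (a sketch inferred from the `.get()` calls above, not the actual `pydantic_ai.settings` source):

    from typing_extensions import TypedDict

    # Sketch only; the real definition lives in pydantic_ai.settings.
    class ModelSettings(TypedDict, total=False):
        max_tokens: int
        temperature: float
        top_p: float
        timeout: float  # request timeout, forwarded to the HTTP client
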
@@ -233,47 +251,55 @@ class GroqAgentModel(AgentModel):
                 start_cost,
             )

-    @staticmethod
-    def _map_message(message: Message) -> chat.ChatCompletionMessageParam:
+    @classmethod
+    def _map_message(cls, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
         """Just maps a `pydantic_ai.Message` to a `groq.types.ChatCompletionMessageParam`."""
-        if message.role == 'system':
-            # SystemPrompt ->
-            return chat.ChatCompletionSystemMessageParam(role='system', content=message.content)
-        elif message.role == 'user':
-            # UserPrompt ->
-            return chat.ChatCompletionUserMessageParam(role='user', content=message.content)
-        elif message.role == 'tool-return':
-            # ToolReturn ->
-            return chat.ChatCompletionToolMessageParam(
-                role='tool',
-                tool_call_id=_guard_tool_id(message),
-                content=message.model_response_str(),
-            )
-        elif message.role == 'retry-prompt':
-            # RetryPrompt ->
-            if message.tool_name is None:
-                return chat.ChatCompletionUserMessageParam(role='user', content=message.model_response())
-            else:
-                return chat.ChatCompletionToolMessageParam(
-                    role='tool',
-                    tool_call_id=_guard_tool_id(message),
-                    content=message.model_response(),
-                )
-        elif message.role == 'model-text-response':
-            # ModelTextResponse ->
-            return chat.ChatCompletionAssistantMessageParam(role='assistant', content=message.content)
-        elif message.role == 'model-structured-response':
-            assert (
-                message.role == 'model-structured-response'
-            ), f'Expected role to be "llm-tool-calls", got {message.role}'
-            # ModelStructuredResponse ->
-            return chat.ChatCompletionAssistantMessageParam(
-                role='assistant',
-                tool_calls=[_map_tool_call(t) for t in message.calls],
-            )
+        if isinstance(message, ModelRequest):
+            yield from cls._map_user_message(message)
+        elif isinstance(message, ModelResponse):
+            texts: list[str] = []
+            tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
+            for item in message.parts:
+                if isinstance(item, TextPart):
+                    texts.append(item.content)
+                elif isinstance(item, ToolCallPart):
+                    tool_calls.append(_map_tool_call(item))
+                else:
+                    assert_never(item)
+            message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')
+            if texts:
+                # Note: model responses from this model should only have one text item, so the following
+                # shouldn't merge multiple texts into one unless you switch models between runs:
+                message_param['content'] = '\n\n'.join(texts)
+            if tool_calls:
+                message_param['tool_calls'] = tool_calls
+            yield message_param
         else:
             assert_never(message)

+    @classmethod
+    def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
+        for part in message.parts:
+            if isinstance(part, SystemPromptPart):
+                yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
+            elif isinstance(part, UserPromptPart):
+                yield chat.ChatCompletionUserMessageParam(role='user', content=part.content)
+            elif isinstance(part, ToolReturnPart):
+                yield chat.ChatCompletionToolMessageParam(
+                    role='tool',
+                    tool_call_id=_guard_tool_call_id(t=part, model_source='Groq'),
+                    content=part.model_response_str(),
+                )
+            elif isinstance(part, RetryPromptPart):
+                if part.tool_name is None:
+                    yield chat.ChatCompletionUserMessageParam(role='user', content=part.model_response())
+                else:
+                    yield chat.ChatCompletionToolMessageParam(
+                        role='tool',
+                        tool_call_id=_guard_tool_call_id(t=part, model_source='Groq'),
+                        content=part.model_response(),
+                    )
+

 @dataclass
 class GroqStreamTextResponse(StreamTextResponse):
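
`_map_message` is now a generator: one `ModelRequest` can expand to several Groq messages (one per part), which is why `_completions_create` flattens the mapped messages with `chain(*)` above. In the other direction, a `ModelResponse` collapses into a single assistant message, joining text parts and attaching tool calls. A sketch of the expansion, reusing the `request` built in the first sketch:

    # list(GroqAgentModel._map_message(request)) yields two params:
    #   {'role': 'system', 'content': 'You are terse.'}
    #   {'role': 'user', 'content': 'What is the capital of France?'}
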
@@ -349,14 +375,14 @@ class GroqStreamStructuredResponse(StreamStructuredResponse):
            else:
                self._delta_tool_calls[new.index] = new

-    def get(self, *, final: bool = False) -> ModelStructuredResponse:
-        calls: list[ToolCall] = []
+    def get(self, *, final: bool = False) -> ModelResponse:
+        items: list[ModelResponsePart] = []
         for c in self._delta_tool_calls.values():
             if f := c.function:
                 if f.name is not None and f.arguments is not None:
-                    calls.append(ToolCall.from_json(f.name, f.arguments, c.id))
+                    items.append(ToolCallPart.from_json(f.name, f.arguments, c.id))

-        return ModelStructuredResponse(calls, timestamp=self._timestamp)
+        return ModelResponse(items, timestamp=self._timestamp)

     def cost(self) -> Cost:
         return self._cost
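
Streamed tool calls arrive as partial deltas, typically the call id and function name in one chunk and the argument JSON in later chunks, so `get()` only converts a buffered delta into a `ToolCallPart` once both halves are present. An illustrative trace with hypothetical values:

    # delta 1: index=0, id='call_1', function.name='get_weather', function.arguments=None
    #   -> buffered; get() skips it (arguments still missing)
    # delta 2: index=0, function.arguments='{"city": "Paris"}'
    #   -> merged into the buffered delta; get() now emits
    #      ToolCallPart.from_json('get_weather', '{"city": "Paris"}', 'call_1')
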
@@ -365,16 +391,10 @@ class GroqStreamStructuredResponse(StreamStructuredResponse):
         return self._timestamp


-def _guard_tool_id(t: ToolCall | ToolReturn | RetryPrompt) -> str:
-    """Type guard that checks a `tool_id` is not None both for static typing and runtime."""
-    assert t.tool_id is not None, f'Groq requires `tool_id` to be set: {t}'
-    return t.tool_id
-
-
-def _map_tool_call(t: ToolCall) -> chat.ChatCompletionMessageToolCallParam:
+def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
     assert isinstance(t.args, ArgsJson), f'Expected ArgsJson, got {t.args}'
     return chat.ChatCompletionMessageToolCallParam(
-        id=_guard_tool_id(t),
+        id=_guard_tool_call_id(t=t, model_source='Groq'),
         type='function',
         function={'name': t.tool_name, 'arguments': t.args.args_json},
     )
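
The module-local `_guard_tool_id` helper is deleted in favour of the shared `guard_tool_call_id` imported from `.._utils` at the top of this diff, so every model backend can apply the same None-check with a per-model error message. The shared helper itself is not part of this diff; a sketch of its likely shape, modelled on the removed local version (the `tool_call_id` field name is assumed from the helper's name):

    def guard_tool_call_id(
        t: ToolCallPart | ToolReturnPart | RetryPromptPart, model_source: str
    ) -> str:
        """Type guard that checks `tool_call_id` is not None, for static typing and at runtime."""
        # Hypothetical sketch; the real implementation lives in pydantic_ai/_utils.py.
        assert t.tool_call_id is not None, f'{model_source} requires `tool_call_id` to be set: {t}'
        return t.tool_call_id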