pydantic-ai-slim 0.0.11__py3-none-any.whl → 0.0.13__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic. Click here for more details.

@@ -1,28 +1,34 @@
1
1
  from __future__ import annotations as _annotations
2
2
 
3
- from collections.abc import AsyncIterator, Iterable, Mapping, Sequence
3
+ from collections.abc import AsyncIterator, Iterable
4
4
  from contextlib import asynccontextmanager
5
5
  from dataclasses import dataclass, field
6
6
  from datetime import datetime, timezone
7
+ from itertools import chain
7
8
  from typing import Literal, overload
8
9
 
9
10
  from httpx import AsyncClient as AsyncHTTPClient
10
11
  from typing_extensions import assert_never
11
12
 
12
13
  from .. import UnexpectedModelBehavior, _utils, result
14
+ from .._utils import guard_tool_call_id as _guard_tool_call_id
13
15
  from ..messages import (
14
16
  ArgsJson,
15
- Message,
16
- ModelAnyResponse,
17
- ModelStructuredResponse,
18
- ModelTextResponse,
19
- RetryPrompt,
20
- ToolCall,
21
- ToolReturn,
17
+ ModelMessage,
18
+ ModelRequest,
19
+ ModelResponse,
20
+ ModelResponsePart,
21
+ RetryPromptPart,
22
+ SystemPromptPart,
23
+ TextPart,
24
+ ToolCallPart,
25
+ ToolReturnPart,
26
+ UserPromptPart,
22
27
  )
23
28
  from ..result import Cost
29
+ from ..settings import ModelSettings
30
+ from ..tools import ToolDefinition
24
31
  from . import (
25
- AbstractToolDefinition,
26
32
  AgentModel,
27
33
  EitherStreamedResponse,
28
34
  Model,
@@ -37,13 +43,14 @@ try:
37
43
  from groq.types import chat
38
44
  from groq.types.chat import ChatCompletion, ChatCompletionChunk
39
45
  from groq.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
40
- except ImportError as e:
46
+ except ImportError as _import_error:
41
47
  raise ImportError(
42
48
  'Please install `groq` to use the Groq model, '
43
- "you can use the `groq` optional group — `pip install 'pydantic-ai[groq]'`"
44
- ) from e
49
+ "you can use the `groq` optional group — `pip install 'pydantic-ai-slim[groq]'`"
50
+ ) from _import_error
45
51
 
46
52
  GroqModelName = Literal[
53
+ 'llama-3.3-70b-versatile',
47
54
  'llama-3.1-70b-versatile',
48
55
  'llama3-groq-70b-8192-tool-use-preview',
49
56
  'llama3-groq-8b-8192-tool-use-preview',
@@ -109,13 +116,14 @@ class GroqModel(Model):
109
116
 
110
117
  async def agent_model(
111
118
  self,
112
- function_tools: Mapping[str, AbstractToolDefinition],
119
+ *,
120
+ function_tools: list[ToolDefinition],
113
121
  allow_text_result: bool,
114
- result_tools: Sequence[AbstractToolDefinition] | None,
122
+ result_tools: list[ToolDefinition],
115
123
  ) -> AgentModel:
116
124
  check_allow_model_requests()
117
- tools = [self._map_tool_definition(r) for r in function_tools.values()]
118
- if result_tools is not None:
125
+ tools = [self._map_tool_definition(r) for r in function_tools]
126
+ if result_tools:
119
127
  tools += [self._map_tool_definition(r) for r in result_tools]
120
128
  return GroqAgentModel(
121
129
  self.client,
@@ -128,13 +136,13 @@ class GroqModel(Model):
128
136
  return f'groq:{self.model_name}'
129
137
 
130
138
  @staticmethod
131
- def _map_tool_definition(f: AbstractToolDefinition) -> chat.ChatCompletionToolParam:
139
+ def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
132
140
  return {
133
141
  'type': 'function',
134
142
  'function': {
135
143
  'name': f.name,
136
144
  'description': f.description,
137
- 'parameters': f.json_schema,
145
+ 'parameters': f.parameters_json_schema,
138
146
  },
139
147
  }
140
148
 
@@ -148,28 +156,34 @@ class GroqAgentModel(AgentModel):
148
156
  allow_text_result: bool
149
157
  tools: list[chat.ChatCompletionToolParam]
150
158
 
151
- async def request(self, messages: list[Message]) -> tuple[ModelAnyResponse, result.Cost]:
152
- response = await self._completions_create(messages, False)
159
+ async def request(
160
+ self, messages: list[ModelMessage], model_settings: ModelSettings | None
161
+ ) -> tuple[ModelResponse, result.Cost]:
162
+ response = await self._completions_create(messages, False, model_settings)
153
163
  return self._process_response(response), _map_cost(response)
154
164
 
155
165
  @asynccontextmanager
156
- async def request_stream(self, messages: list[Message]) -> AsyncIterator[EitherStreamedResponse]:
157
- response = await self._completions_create(messages, True)
166
+ async def request_stream(
167
+ self, messages: list[ModelMessage], model_settings: ModelSettings | None
168
+ ) -> AsyncIterator[EitherStreamedResponse]:
169
+ response = await self._completions_create(messages, True, model_settings)
158
170
  async with response:
159
171
  yield await self._process_streamed_response(response)
160
172
 
161
173
  @overload
162
174
  async def _completions_create(
163
- self, messages: list[Message], stream: Literal[True]
175
+ self, messages: list[ModelMessage], stream: Literal[True], model_settings: ModelSettings | None
164
176
  ) -> AsyncStream[ChatCompletionChunk]:
165
177
  pass
166
178
 
167
179
  @overload
168
- async def _completions_create(self, messages: list[Message], stream: Literal[False]) -> chat.ChatCompletion:
180
+ async def _completions_create(
181
+ self, messages: list[ModelMessage], stream: Literal[False], model_settings: ModelSettings | None
182
+ ) -> chat.ChatCompletion:
169
183
  pass
170
184
 
171
185
  async def _completions_create(
172
- self, messages: list[Message], stream: bool
186
+ self, messages: list[ModelMessage], stream: bool, model_settings: ModelSettings | None
173
187
  ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
174
188
  # standalone function to make it easier to override
175
189
  if not self.tools:
@@ -179,103 +193,112 @@ class GroqAgentModel(AgentModel):
179
193
  else:
180
194
  tool_choice = 'auto'
181
195
 
182
- groq_messages = [self._map_message(m) for m in messages]
196
+ groq_messages = list(chain(*(self._map_message(m) for m in messages)))
197
+
198
+ model_settings = model_settings or {}
199
+
183
200
  return await self.client.chat.completions.create(
184
201
  model=str(self.model_name),
185
202
  messages=groq_messages,
186
- temperature=0.0,
187
203
  n=1,
188
204
  parallel_tool_calls=True if self.tools else NOT_GIVEN,
189
205
  tools=self.tools or NOT_GIVEN,
190
206
  tool_choice=tool_choice or NOT_GIVEN,
191
207
  stream=stream,
208
+ max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
209
+ temperature=model_settings.get('temperature', NOT_GIVEN),
210
+ top_p=model_settings.get('top_p', NOT_GIVEN),
211
+ timeout=model_settings.get('timeout', NOT_GIVEN),
192
212
  )
193
213
 
194
214
  @staticmethod
195
- def _process_response(response: chat.ChatCompletion) -> ModelAnyResponse:
215
+ def _process_response(response: chat.ChatCompletion) -> ModelResponse:
196
216
  """Process a non-streamed response, and prepare a message to return."""
197
217
  timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
198
218
  choice = response.choices[0]
219
+ items: list[ModelResponsePart] = []
220
+ if choice.message.content is not None:
221
+ items.append(TextPart(choice.message.content))
199
222
  if choice.message.tool_calls is not None:
200
- return ModelStructuredResponse(
201
- [ToolCall.from_json(c.function.name, c.function.arguments, c.id) for c in choice.message.tool_calls],
202
- timestamp=timestamp,
203
- )
204
- else:
205
- assert choice.message.content is not None, choice
206
- return ModelTextResponse(choice.message.content, timestamp=timestamp)
223
+ for c in choice.message.tool_calls:
224
+ items.append(ToolCallPart.from_json(c.function.name, c.function.arguments, c.id))
225
+ return ModelResponse(items, timestamp=timestamp)
207
226
 
208
227
  @staticmethod
209
228
  async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> EitherStreamedResponse:
210
229
  """Process a streamed response, and prepare a streaming response to return."""
211
- try:
212
- first_chunk = await response.__anext__()
213
- except StopAsyncIteration as e: # pragma: no cover
214
- raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e
215
- timestamp = datetime.fromtimestamp(first_chunk.created, tz=timezone.utc)
216
- delta = first_chunk.choices[0].delta
217
- start_cost = _map_cost(first_chunk)
218
-
219
- # the first chunk may only contain `role`, so we iterate until we get either `tool_calls` or `content`
220
- while delta.tool_calls is None and delta.content is None:
230
+ timestamp: datetime | None = None
231
+ start_cost = Cost()
232
+ # the first chunk may contain enough information so we iterate until we get either `tool_calls` or `content`
233
+ while True:
221
234
  try:
222
- next_chunk = await response.__anext__()
235
+ chunk = await response.__anext__()
223
236
  except StopAsyncIteration as e:
224
237
  raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e
225
- delta = next_chunk.choices[0].delta
226
- start_cost += _map_cost(next_chunk)
227
-
228
- if delta.content is not None:
229
- return GroqStreamTextResponse(delta.content, response, timestamp, start_cost)
238
+ timestamp = timestamp or datetime.fromtimestamp(chunk.created, tz=timezone.utc)
239
+ start_cost += _map_cost(chunk)
240
+
241
+ if chunk.choices:
242
+ delta = chunk.choices[0].delta
243
+
244
+ if delta.content is not None:
245
+ return GroqStreamTextResponse(delta.content, response, timestamp, start_cost)
246
+ elif delta.tool_calls is not None:
247
+ return GroqStreamStructuredResponse(
248
+ response,
249
+ {c.index: c for c in delta.tool_calls},
250
+ timestamp,
251
+ start_cost,
252
+ )
253
+
254
+ @classmethod
255
+ def _map_message(cls, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
256
+ """Just maps a `pydantic_ai.Message` to a `groq.types.ChatCompletionMessageParam`."""
257
+ if isinstance(message, ModelRequest):
258
+ yield from cls._map_user_message(message)
259
+ elif isinstance(message, ModelResponse):
260
+ texts: list[str] = []
261
+ tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
262
+ for item in message.parts:
263
+ if isinstance(item, TextPart):
264
+ texts.append(item.content)
265
+ elif isinstance(item, ToolCallPart):
266
+ tool_calls.append(_map_tool_call(item))
267
+ else:
268
+ assert_never(item)
269
+ message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')
270
+ if texts:
271
+ # Note: model responses from this model should only have one text item, so the following
272
+ # shouldn't merge multiple texts into one unless you switch models between runs:
273
+ message_param['content'] = '\n\n'.join(texts)
274
+ if tool_calls:
275
+ message_param['tool_calls'] = tool_calls
276
+ yield message_param
230
277
  else:
231
- assert delta.tool_calls is not None, f'Expected delta with tool_calls, got {delta}'
232
- return GroqStreamStructuredResponse(
233
- response,
234
- {c.index: c for c in delta.tool_calls},
235
- timestamp,
236
- start_cost,
237
- )
278
+ assert_never(message)
238
279
 
239
- @staticmethod
240
- def _map_message(message: Message) -> chat.ChatCompletionMessageParam:
241
- """Just maps a `pydantic_ai.Message` to a `groq.types.ChatCompletionMessageParam`."""
242
- if message.role == 'system':
243
- # SystemPrompt ->
244
- return chat.ChatCompletionSystemMessageParam(role='system', content=message.content)
245
- elif message.role == 'user':
246
- # UserPrompt ->
247
- return chat.ChatCompletionUserMessageParam(role='user', content=message.content)
248
- elif message.role == 'tool-return':
249
- # ToolReturn ->
250
- return chat.ChatCompletionToolMessageParam(
251
- role='tool',
252
- tool_call_id=_guard_tool_id(message),
253
- content=message.model_response_str(),
254
- )
255
- elif message.role == 'retry-prompt':
256
- # RetryPrompt ->
257
- if message.tool_name is None:
258
- return chat.ChatCompletionUserMessageParam(role='user', content=message.model_response())
259
- else:
260
- return chat.ChatCompletionToolMessageParam(
280
+ @classmethod
281
+ def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
282
+ for part in message.parts:
283
+ if isinstance(part, SystemPromptPart):
284
+ yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
285
+ elif isinstance(part, UserPromptPart):
286
+ yield chat.ChatCompletionUserMessageParam(role='user', content=part.content)
287
+ elif isinstance(part, ToolReturnPart):
288
+ yield chat.ChatCompletionToolMessageParam(
261
289
  role='tool',
262
- tool_call_id=_guard_tool_id(message),
263
- content=message.model_response(),
290
+ tool_call_id=_guard_tool_call_id(t=part, model_source='Groq'),
291
+ content=part.model_response_str(),
264
292
  )
265
- elif message.role == 'model-text-response':
266
- # ModelTextResponse ->
267
- return chat.ChatCompletionAssistantMessageParam(role='assistant', content=message.content)
268
- elif message.role == 'model-structured-response':
269
- assert (
270
- message.role == 'model-structured-response'
271
- ), f'Expected role to be "llm-tool-calls", got {message.role}'
272
- # ModelStructuredResponse ->
273
- return chat.ChatCompletionAssistantMessageParam(
274
- role='assistant',
275
- tool_calls=[_map_tool_call(t) for t in message.calls],
276
- )
277
- else:
278
- assert_never(message)
293
+ elif isinstance(part, RetryPromptPart):
294
+ if part.tool_name is None:
295
+ yield chat.ChatCompletionUserMessageParam(role='user', content=part.model_response())
296
+ else:
297
+ yield chat.ChatCompletionToolMessageParam(
298
+ role='tool',
299
+ tool_call_id=_guard_tool_call_id(t=part, model_source='Groq'),
300
+ content=part.model_response(),
301
+ )
279
302
 
280
303
 
281
304
  @dataclass
@@ -352,14 +375,14 @@ class GroqStreamStructuredResponse(StreamStructuredResponse):
352
375
  else:
353
376
  self._delta_tool_calls[new.index] = new
354
377
 
355
- def get(self, *, final: bool = False) -> ModelStructuredResponse:
356
- calls: list[ToolCall] = []
378
+ def get(self, *, final: bool = False) -> ModelResponse:
379
+ items: list[ModelResponsePart] = []
357
380
  for c in self._delta_tool_calls.values():
358
381
  if f := c.function:
359
382
  if f.name is not None and f.arguments is not None:
360
- calls.append(ToolCall.from_json(f.name, f.arguments, c.id))
383
+ items.append(ToolCallPart.from_json(f.name, f.arguments, c.id))
361
384
 
362
- return ModelStructuredResponse(calls, timestamp=self._timestamp)
385
+ return ModelResponse(items, timestamp=self._timestamp)
363
386
 
364
387
  def cost(self) -> Cost:
365
388
  return self._cost
@@ -368,16 +391,10 @@ class GroqStreamStructuredResponse(StreamStructuredResponse):
368
391
  return self._timestamp
369
392
 
370
393
 
371
- def _guard_tool_id(t: ToolCall | ToolReturn | RetryPrompt) -> str:
372
- """Type guard that checks a `tool_id` is not None both for static typing and runtime."""
373
- assert t.tool_id is not None, f'Groq requires `tool_id` to be set: {t}'
374
- return t.tool_id
375
-
376
-
377
- def _map_tool_call(t: ToolCall) -> chat.ChatCompletionMessageToolCallParam:
394
+ def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
378
395
  assert isinstance(t.args, ArgsJson), f'Expected ArgsJson, got {t.args}'
379
396
  return chat.ChatCompletionMessageToolCallParam(
380
- id=_guard_tool_id(t),
397
+ id=_guard_tool_call_id(t=t, model_source='Groq'),
381
398
  type='function',
382
399
  function={'name': t.tool_name, 'arguments': t.args.args_json},
383
400
  )