pydantic-ai-slim 0.0.19__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: the registry flags this release of pydantic-ai-slim as potentially problematic.

@@ -0,0 +1,290 @@
+from __future__ import annotations as _annotations
+
+from collections.abc import Iterable
+from dataclasses import dataclass, field
+from itertools import chain
+from typing import Literal, Union, cast
+
+from cohere import TextAssistantMessageContentItem
+from httpx import AsyncClient as AsyncHTTPClient
+from typing_extensions import assert_never
+
+from .. import result
+from .._utils import guard_tool_call_id as _guard_tool_call_id
+from ..messages import (
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    ModelResponsePart,
+    RetryPromptPart,
+    SystemPromptPart,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
+    UserPromptPart,
+)
+from ..settings import ModelSettings
+from ..tools import ToolDefinition
+from . import (
+    AgentModel,
+    Model,
+    check_allow_model_requests,
+)
+
+try:
+    from cohere import (
+        AssistantChatMessageV2,
+        AsyncClientV2,
+        ChatMessageV2,
+        ChatResponse,
+        SystemChatMessageV2,
+        ToolCallV2,
+        ToolCallV2Function,
+        ToolChatMessageV2,
+        ToolV2,
+        ToolV2Function,
+        UserChatMessageV2,
+    )
+    from cohere.v2.client import OMIT
+except ImportError as _import_error:
+    raise ImportError(
+        'Please install `cohere` to use the Cohere model, '
+        "you can use the `cohere` optional group — `pip install 'pydantic-ai-slim[cohere]'`"
+    ) from _import_error
+
+NamedCohereModels = Literal[
+    'c4ai-aya-expanse-32b',
+    'c4ai-aya-expanse-8b',
+    'command',
+    'command-light',
+    'command-light-nightly',
+    'command-nightly',
+    'command-r',
+    'command-r-03-2024',
+    'command-r-08-2024',
+    'command-r-plus',
+    'command-r-plus-04-2024',
+    'command-r-plus-08-2024',
+    'command-r7b-12-2024',
+]
+"""Latest / most popular named Cohere models."""
+
+CohereModelName = Union[NamedCohereModels, str]
+
+
+class CohereModelSettings(ModelSettings):
+    """Settings used for a Cohere model request."""
+
+    # This class is a placeholder for any future cohere-specific settings
+
+
+@dataclass(init=False)
+class CohereModel(Model):
+    """A model that uses the Cohere API.
+
+    Internally, this uses the [Cohere Python client](
+    https://github.com/cohere-ai/cohere-python) to interact with the API.
+
+    Apart from `__init__`, all methods are private or match those of the base class.
+    """
+
+    model_name: CohereModelName
+    client: AsyncClientV2 = field(repr=False)
+
+    def __init__(
+        self,
+        model_name: CohereModelName,
+        *,
+        api_key: str | None = None,
+        cohere_client: AsyncClientV2 | None = None,
+        http_client: AsyncHTTPClient | None = None,
+    ):
+        """Initialize an Cohere model.
+
+        Args:
+            model_name: The name of the Cohere model to use. List of model names
+                available [here](https://docs.cohere.com/docs/models#command).
+            api_key: The API key to use for authentication, if not provided, the
+                `CO_API_KEY` environment variable will be used if available.
+            cohere_client: An existing Cohere async client to use. If provided,
+                `api_key` and `http_client` must be `None`.
+            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
+        """
+        self.model_name: CohereModelName = model_name
+        if cohere_client is not None:
+            assert http_client is None, 'Cannot provide both `cohere_client` and `http_client`'
+            assert api_key is None, 'Cannot provide both `cohere_client` and `api_key`'
+            self.client = cohere_client
+        else:
+            self.client = AsyncClientV2(api_key=api_key, httpx_client=http_client)  # type: ignore
+
+    async def agent_model(
+        self,
+        *,
+        function_tools: list[ToolDefinition],
+        allow_text_result: bool,
+        result_tools: list[ToolDefinition],
+    ) -> AgentModel:
+        check_allow_model_requests()
+        tools = [self._map_tool_definition(r) for r in function_tools]
+        if result_tools:
+            tools += [self._map_tool_definition(r) for r in result_tools]
+        return CohereAgentModel(
+            self.client,
+            self.model_name,
+            allow_text_result,
+            tools,
+        )
+
+    def name(self) -> str:
+        return f'cohere:{self.model_name}'
+
+    @staticmethod
+    def _map_tool_definition(f: ToolDefinition) -> ToolV2:
+        return ToolV2(
+            type='function',
+            function=ToolV2Function(
+                name=f.name,
+                description=f.description,
+                parameters=f.parameters_json_schema,
+            ),
+        )
+
+
+@dataclass
+class CohereAgentModel(AgentModel):
+    """Implementation of `AgentModel` for Cohere models."""
+
+    client: AsyncClientV2
+    model_name: CohereModelName
+    allow_text_result: bool
+    tools: list[ToolV2]
+
+    async def request(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> tuple[ModelResponse, result.Usage]:
+        response = await self._chat(messages, cast(CohereModelSettings, model_settings or {}))
+        return self._process_response(response), _map_usage(response)
+
+    async def _chat(
+        self,
+        messages: list[ModelMessage],
+        model_settings: CohereModelSettings,
+    ) -> ChatResponse:
+        cohere_messages = list(chain(*(self._map_message(m) for m in messages)))
+        return await self.client.chat(
+            model=self.model_name,
+            messages=cohere_messages,
+            tools=self.tools or OMIT,
+            max_tokens=model_settings.get('max_tokens', OMIT),
+            temperature=model_settings.get('temperature', OMIT),
+            p=model_settings.get('top_p', OMIT),
+            seed=model_settings.get('seed', OMIT),
+            presence_penalty=model_settings.get('presence_penalty', OMIT),
+            frequency_penalty=model_settings.get('frequency_penalty', OMIT),
+        )
+
+    def _process_response(self, response: ChatResponse) -> ModelResponse:
+        """Process a non-streamed response, and prepare a message to return."""
+        parts: list[ModelResponsePart] = []
+        if response.message.content is not None and len(response.message.content) > 0:
+            # While Cohere's API returns a list, it only does that for future proofing
+            # and currently only one item is being returned.
+            choice = response.message.content[0]
+            parts.append(TextPart(choice.text))
+        for c in response.message.tool_calls or []:
+            if c.function and c.function.name and c.function.arguments:
+                parts.append(
+                    ToolCallPart(
+                        tool_name=c.function.name,
+                        args=c.function.arguments,
+                        tool_call_id=c.id,
+                    )
+                )
+        return ModelResponse(parts=parts, model_name=self.model_name)
+
+    @classmethod
+    def _map_message(cls, message: ModelMessage) -> Iterable[ChatMessageV2]:
+        """Just maps a `pydantic_ai.Message` to a `cohere.ChatMessageV2`."""
+        if isinstance(message, ModelRequest):
+            yield from cls._map_user_message(message)
+        elif isinstance(message, ModelResponse):
+            texts: list[str] = []
+            tool_calls: list[ToolCallV2] = []
+            for item in message.parts:
+                if isinstance(item, TextPart):
+                    texts.append(item.content)
+                elif isinstance(item, ToolCallPart):
+                    tool_calls.append(_map_tool_call(item))
+                else:
+                    assert_never(item)
+            message_param = AssistantChatMessageV2(role='assistant')
+            if texts:
+                message_param.content = [TextAssistantMessageContentItem(text='\n\n'.join(texts))]
+            if tool_calls:
+                message_param.tool_calls = tool_calls
+            yield message_param
+        else:
+            assert_never(message)
+
+    @classmethod
+    def _map_user_message(cls, message: ModelRequest) -> Iterable[ChatMessageV2]:
+        for part in message.parts:
+            if isinstance(part, SystemPromptPart):
+                yield SystemChatMessageV2(role='system', content=part.content)
+            elif isinstance(part, UserPromptPart):
+                yield UserChatMessageV2(role='user', content=part.content)
+            elif isinstance(part, ToolReturnPart):
+                yield ToolChatMessageV2(
+                    role='tool',
+                    tool_call_id=_guard_tool_call_id(t=part, model_source='Cohere'),
+                    content=part.model_response_str(),
+                )
+            elif isinstance(part, RetryPromptPart):
+                if part.tool_name is None:
+                    yield UserChatMessageV2(role='user', content=part.model_response())
+                else:
+                    yield ToolChatMessageV2(
+                        role='tool',
+                        tool_call_id=_guard_tool_call_id(t=part, model_source='Cohere'),
+                        content=part.model_response(),
+                    )
+            else:
+                assert_never(part)
+
+
+def _map_tool_call(t: ToolCallPart) -> ToolCallV2:
+    return ToolCallV2(
+        id=_guard_tool_call_id(t=t, model_source='Cohere'),
+        type='function',
+        function=ToolCallV2Function(
+            name=t.tool_name,
+            arguments=t.args_as_json_str(),
+        ),
+    )
+
+
+def _map_usage(response: ChatResponse) -> result.Usage:
+    usage = response.usage
+    if usage is None:
+        return result.Usage()
+    else:
+        details: dict[str, int] = {}
+        if usage.billed_units is not None:
+            if usage.billed_units.input_tokens:
+                details['input_tokens'] = int(usage.billed_units.input_tokens)
+            if usage.billed_units.output_tokens:
+                details['output_tokens'] = int(usage.billed_units.output_tokens)
+            if usage.billed_units.search_units:
+                details['search_units'] = int(usage.billed_units.search_units)
+            if usage.billed_units.classifications:
+                details['classifications'] = int(usage.billed_units.classifications)
+
+        request_tokens = int(usage.tokens.input_tokens) if usage.tokens and usage.tokens.input_tokens else None
+        response_tokens = int(usage.tokens.output_tokens) if usage.tokens and usage.tokens.output_tokens else None
+        return result.Usage(
+            request_tokens=request_tokens,
+            response_tokens=response_tokens,
+            total_tokens=(request_tokens or 0) + (response_tokens or 0),
+            details=details,
+        )
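
The hunk above is a new module adding Cohere support: CohereModel wraps the async Cohere v2 client and CohereAgentModel maps pydantic-ai messages to Cohere chat messages. A minimal usage sketch follows; it assumes the module is exposed as `pydantic_ai.models.cohere`, that the `cohere` optional group is installed, and that `CO_API_KEY` is set. The model name and prompt are illustrative only.

from pydantic_ai import Agent
from pydantic_ai.models.cohere import CohereModel

# With no api_key argument, CohereModel falls back to the CO_API_KEY environment variable.
model = CohereModel('command-r-plus')
agent = Agent(model, system_prompt='Reply in one short sentence.')

result = agent.run_sync('What is the capital of France?')
print(result.data)     # the text answer extracted by _process_response
print(result.usage())  # token counts built by _map_usage from Cohere's usage metadata

Function tools registered on the agent are translated to Cohere ToolV2 definitions by _map_tool_definition, so tool calling works the same way as with the other providers.
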
@@ -71,16 +71,15 @@ class FunctionModel(Model):
         result_tools: list[ToolDefinition],
     ) -> AgentModel:
         return FunctionAgentModel(
-            self.function, self.stream_function, AgentInfo(function_tools, allow_text_result, result_tools, None)
+            self.function,
+            self.stream_function,
+            AgentInfo(function_tools, allow_text_result, result_tools, None),
         )

     def name(self) -> str:
-        labels: list[str] = []
-        if self.function is not None:
-            labels.append(self.function.__name__)
-        if self.stream_function is not None:
-            labels.append(f'stream-{self.stream_function.__name__}')
-        return f'function:{",".join(labels)}'
+        function_name = self.function.__name__ if self.function is not None else ''
+        stream_function_name = self.stream_function.__name__ if self.stream_function is not None else ''
+        return f'function:{function_name}:{stream_function_name}'


 @dataclass(frozen=True)
@@ -147,12 +146,15 @@ class FunctionAgentModel(AgentModel):
         agent_info = replace(self.agent_info, model_settings=model_settings)

         assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests'
+        model_name = f'function:{self.function.__name__}'
+
         if inspect.iscoroutinefunction(self.function):
             response = await self.function(messages, agent_info)
         else:
             response_ = await _utils.run_in_executor(self.function, messages, agent_info)
             assert isinstance(response_, ModelResponse), response_
             response = response_
+        response.model_name = model_name
         # TODO is `messages` right here? Should it just be new messages?
         return response, _estimate_usage(chain(messages, [response]))

@@ -163,13 +165,15 @@ class FunctionAgentModel(AgentModel):
         assert (
             self.stream_function is not None
         ), 'FunctionModel must receive a `stream_function` to support streamed requests'
+        model_name = f'function:{self.stream_function.__name__}'
+
         response_stream = PeekableAsyncStream(self.stream_function(messages, self.agent_info))

         first = await response_stream.peek()
         if isinstance(first, _utils.Unset):
             raise ValueError('Stream function must return at least one item')

-        yield FunctionStreamedResponse(response_stream)
+        yield FunctionStreamedResponse(_model_name=model_name, _iter=response_stream)


 @dataclass
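
The FunctionModel changes above stamp a model name of the form `function:<name>` onto responses and simplify `name()` to `function:<function>:<stream_function>`. A small sketch of how FunctionModel is typically used to exercise an agent without a real provider, assuming the usual `pydantic_ai.models.function` import path:

from pydantic_ai import Agent
from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
from pydantic_ai.models.function import AgentInfo, FunctionModel

def always_paris(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
    # A canned response; after this change FunctionAgentModel.request() sets
    # its model_name to 'function:always_paris'.
    return ModelResponse(parts=[TextPart(content='Paris')])

agent = Agent(FunctionModel(always_paris))
result = agent.run_sync('What is the capital of France?')
print(result.data)  # Paris
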
@@ -7,7 +7,7 @@ from contextlib import asynccontextmanager
 from copy import deepcopy
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Annotated, Any, Literal, Protocol, Union
+from typing import Annotated, Any, Literal, Protocol, Union, cast
 from uuid import uuid4

 import pydantic
@@ -48,6 +48,12 @@ See [the Gemini API docs](https://ai.google.dev/gemini-api/docs/models/gemini#mo
 """


+class GeminiModelSettings(ModelSettings):
+    """Settings used for a Gemini model request."""
+
+    # This class is a placeholder for any future gemini-specific settings
+
+
 @dataclass(init=False)
 class GeminiModel(Model):
     """A model that uses Gemini via `generativelanguage.googleapis.com` API.
@@ -99,6 +105,7 @@ class GeminiModel(Model):
         allow_text_result: bool,
         result_tools: list[ToolDefinition],
     ) -> GeminiAgentModel:
+        check_allow_model_requests()
         return GeminiAgentModel(
             http_client=self.http_client,
             model_name=self.model_name,
@@ -151,7 +158,6 @@ class GeminiAgentModel(AgentModel):
         allow_text_result: bool,
         result_tools: list[ToolDefinition],
     ):
-        check_allow_model_requests()
         tools = [_function_from_abstract_tool(t) for t in function_tools]
         if result_tools:
             tools += [_function_from_abstract_tool(t) for t in result_tools]
@@ -171,7 +177,9 @@ class GeminiAgentModel(AgentModel):
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> tuple[ModelResponse, usage.Usage]:
-        async with self._make_request(messages, False, model_settings) as http_response:
+        async with self._make_request(
+            messages, False, cast(GeminiModelSettings, model_settings or {})
+        ) as http_response:
             response = _gemini_response_ta.validate_json(await http_response.aread())
         return self._process_response(response), _metadata_as_usage(response)

@@ -179,12 +187,12 @@ class GeminiAgentModel(AgentModel):
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> AsyncIterator[StreamedResponse]:
-        async with self._make_request(messages, True, model_settings) as http_response:
+        async with self._make_request(messages, True, cast(GeminiModelSettings, model_settings or {})) as http_response:
             yield await self._process_streamed_response(http_response)

     @asynccontextmanager
     async def _make_request(
-        self, messages: list[ModelMessage], streamed: bool, model_settings: ModelSettings | None
+        self, messages: list[ModelMessage], streamed: bool, model_settings: GeminiModelSettings
     ) -> AsyncIterator[HTTPResponse]:
         sys_prompt_parts, contents = self._message_to_gemini_content(messages)

@@ -204,6 +212,10 @@ class GeminiAgentModel(AgentModel):
             generation_config['temperature'] = temperature
         if (top_p := model_settings.get('top_p')) is not None:
             generation_config['top_p'] = top_p
+        if (presence_penalty := model_settings.get('presence_penalty')) is not None:
+            generation_config['presence_penalty'] = presence_penalty
+        if (frequency_penalty := model_settings.get('frequency_penalty')) is not None:
+            generation_config['frequency_penalty'] = frequency_penalty
         if generation_config:
             request_data['generation_config'] = generation_config

@@ -222,22 +234,20 @@ class GeminiAgentModel(AgentModel):
             url,
             content=request_json,
             headers=headers,
-            timeout=(model_settings or {}).get('timeout', USE_CLIENT_DEFAULT),
+            timeout=model_settings.get('timeout', USE_CLIENT_DEFAULT),
         ) as r:
             if r.status_code != 200:
                 await r.aread()
                 raise exceptions.UnexpectedModelBehavior(f'Unexpected response from gemini {r.status_code}', r.text)
             yield r

-    @staticmethod
-    def _process_response(response: _GeminiResponse) -> ModelResponse:
+    def _process_response(self, response: _GeminiResponse) -> ModelResponse:
         if len(response['candidates']) != 1:
             raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response')
         parts = response['candidates'][0]['content']['parts']
-        return _process_response_from_parts(parts)
+        return _process_response_from_parts(parts, model_name=self.model_name)

-    @staticmethod
-    async def _process_streamed_response(http_response: HTTPResponse) -> StreamedResponse:
+    async def _process_streamed_response(self, http_response: HTTPResponse) -> StreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         aiter_bytes = http_response.aiter_bytes()
         start_response: _GeminiResponse | None = None
@@ -258,7 +268,7 @@ class GeminiAgentModel(AgentModel):
         if start_response is None:
             raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')

-        return GeminiStreamedResponse(_content=content, _stream=aiter_bytes)
+        return GeminiStreamedResponse(_model_name=self.model_name, _content=content, _stream=aiter_bytes)

     @classmethod
     def _message_to_gemini_content(
@@ -400,6 +410,8 @@ class _GeminiGenerationConfig(TypedDict, total=False):
     max_output_tokens: int
     temperature: float
     top_p: float
+    presence_penalty: float
+    frequency_penalty: float


 class _GeminiContent(TypedDict):
@@ -432,14 +444,16 @@ def _function_call_part_from_call(tool: ToolCallPart) -> _GeminiFunctionCallPart
     return _GeminiFunctionCallPart(function_call=_GeminiFunctionCall(name=tool.tool_name, args=tool.args_as_dict()))


-def _process_response_from_parts(parts: Sequence[_GeminiPartUnion], timestamp: datetime | None = None) -> ModelResponse:
+def _process_response_from_parts(
+    parts: Sequence[_GeminiPartUnion], model_name: GeminiModelName, timestamp: datetime | None = None
+) -> ModelResponse:
     items: list[ModelResponsePart] = []
     for part in parts:
         if 'text' in part:
             items.append(TextPart(content=part['text']))
         elif 'function_call' in part:
             items.append(
-                ToolCallPart.from_raw_args(
+                ToolCallPart(
                     tool_name=part['function_call']['name'],
                     args=part['function_call']['args'],
                 )
@@ -448,7 +462,7 @@ def _process_response_from_parts(parts: Sequence[_GeminiPartUnion], timestamp: d
             raise exceptions.UnexpectedModelBehavior(
                 f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
             )
-    return ModelResponse(items, timestamp=timestamp or _utils.now_utc())
+    return ModelResponse(parts=items, model_name=model_name, timestamp=timestamp or _utils.now_utc())


 class _GeminiFunctionCall(TypedDict):
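
The Gemini changes above introduce a GeminiModelSettings placeholder, move check_allow_model_requests() into agent_model, thread the model name onto responses, and forward presence_penalty and frequency_penalty into the request's generation_config. A sketch of passing those settings through an agent run, assuming GEMINI_API_KEY is configured and that run_sync accepts a model_settings dict as elsewhere in the 0.0.2x API; the model name and prompt are illustrative:

from pydantic_ai import Agent

agent = Agent('gemini-1.5-flash')

result = agent.run_sync(
    'Name three rivers.',
    model_settings={
        'temperature': 0.2,
        # Newly forwarded into generation_config by GeminiAgentModel._make_request:
        'presence_penalty': 0.5,
        'frequency_penalty': 0.5,
    },
)
print(result.data)
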
@@ -5,7 +5,7 @@ from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
 from itertools import chain
-from typing import Literal, overload
+from typing import Literal, cast, overload

 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never
@@ -47,10 +47,7 @@ except ImportError as _import_error:

 GroqModelName = Literal[
     'llama-3.3-70b-versatile',
-    'llama-3.1-70b-versatile',
-    'llama3-groq-70b-8192-tool-use-preview',
-    'llama3-groq-8b-8192-tool-use-preview',
-    'llama-3.1-70b-specdec',
+    'llama-3.3-70b-specdec',
     'llama-3.1-8b-instant',
     'llama-3.2-1b-preview',
     'llama-3.2-3b-preview',
@@ -60,7 +57,6 @@ GroqModelName = Literal[
     'llama3-8b-8192',
     'mixtral-8x7b-32768',
     'gemma2-9b-it',
-    'gemma-7b-it',
 ]
 """Named Groq models.

@@ -68,6 +64,12 @@ See [the Groq docs](https://console.groq.com/docs/models) for a full list.
 """


+class GroqModelSettings(ModelSettings):
+    """Settings used for a Groq model request."""
+
+    # This class is a placeholder for any future groq-specific settings
+
+
 @dataclass(init=False)
 class GroqModel(Model):
     """A model that uses the Groq API.
@@ -155,31 +157,31 @@ class GroqAgentModel(AgentModel):
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> tuple[ModelResponse, usage.Usage]:
-        response = await self._completions_create(messages, False, model_settings)
+        response = await self._completions_create(messages, False, cast(GroqModelSettings, model_settings or {}))
         return self._process_response(response), _map_usage(response)

     @asynccontextmanager
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> AsyncIterator[StreamedResponse]:
-        response = await self._completions_create(messages, True, model_settings)
+        response = await self._completions_create(messages, True, cast(GroqModelSettings, model_settings or {}))
         async with response:
             yield await self._process_streamed_response(response)

     @overload
     async def _completions_create(
-        self, messages: list[ModelMessage], stream: Literal[True], model_settings: ModelSettings | None
+        self, messages: list[ModelMessage], stream: Literal[True], model_settings: GroqModelSettings
     ) -> AsyncStream[ChatCompletionChunk]:
         pass

     @overload
     async def _completions_create(
-        self, messages: list[ModelMessage], stream: Literal[False], model_settings: ModelSettings | None
+        self, messages: list[ModelMessage], stream: Literal[False], model_settings: GroqModelSettings
    ) -> chat.ChatCompletion:
         pass

     async def _completions_create(
-        self, messages: list[ModelMessage], stream: bool, model_settings: ModelSettings | None
+        self, messages: list[ModelMessage], stream: bool, model_settings: GroqModelSettings
     ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
         # standalone function to make it easier to override
         if not self.tools:
@@ -191,13 +193,11 @@ class GroqAgentModel(AgentModel):

         groq_messages = list(chain(*(self._map_message(m) for m in messages)))

-        model_settings = model_settings or {}
-
         return await self.client.chat.completions.create(
             model=str(self.model_name),
             messages=groq_messages,
             n=1,
-            parallel_tool_calls=True if self.tools else NOT_GIVEN,
+            parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
             tools=self.tools or NOT_GIVEN,
             tool_choice=tool_choice or NOT_GIVEN,
             stream=stream,
@@ -205,10 +205,13 @@ class GroqAgentModel(AgentModel):
             temperature=model_settings.get('temperature', NOT_GIVEN),
             top_p=model_settings.get('top_p', NOT_GIVEN),
             timeout=model_settings.get('timeout', NOT_GIVEN),
+            seed=model_settings.get('seed', NOT_GIVEN),
+            presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
+            frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
+            logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
         )

-    @staticmethod
-    def _process_response(response: chat.ChatCompletion) -> ModelResponse:
+    def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
         timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
         choice = response.choices[0]
@@ -217,20 +220,21 @@ class GroqAgentModel(AgentModel):
             items.append(TextPart(content=choice.message.content))
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
-                items.append(
-                    ToolCallPart.from_raw_args(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id)
-                )
-        return ModelResponse(items, timestamp=timestamp)
+                items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
+        return ModelResponse(items, model_name=self.model_name, timestamp=timestamp)

-    @staticmethod
-    async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> GroqStreamedResponse:
+    async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> GroqStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         peekable_response = _utils.PeekableAsyncStream(response)
         first_chunk = await peekable_response.peek()
         if isinstance(first_chunk, _utils.Unset):
             raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')

-        return GroqStreamedResponse(peekable_response, datetime.fromtimestamp(first_chunk.created, tz=timezone.utc))
+        return GroqStreamedResponse(
+            _response=peekable_response,
+            _model_name=self.model_name,
+            _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
+        )

     @classmethod
     def _map_message(cls, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
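
The Groq changes above prune retired model names, add a GroqModelSettings placeholder, attach the model name to responses, and forward seed, presence_penalty, frequency_penalty, logit_bias, and parallel_tool_calls from the settings dict instead of hard-coding them. A sketch under the same assumptions as the Gemini example (illustrative model name and prompt, GROQ_API_KEY set):

from pydantic_ai import Agent

agent = Agent('groq:llama-3.3-70b-versatile')

result = agent.run_sync(
    'Summarise the plot of Hamlet in one sentence.',
    # seed and the penalty settings are now forwarded to the Groq chat completions
    # call by GroqAgentModel._completions_create.
    model_settings={'temperature': 0.0, 'seed': 42},
)
print(result.data)
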