pydantic-ai-slim 0.0.11 → 0.0.13 (py3-none-any wheels)

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions exactly as they were published in their public registry.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic; the original diff page links to further details about the concern.

@@ -2,27 +2,29 @@ from __future__ import annotations as _annotations
2
2
 
3
3
  import re
4
4
  import string
5
- from collections.abc import AsyncIterator, Iterable, Iterator, Mapping, Sequence
5
+ from collections.abc import AsyncIterator, Iterable, Iterator
6
6
  from contextlib import asynccontextmanager
7
7
  from dataclasses import dataclass, field
8
8
  from datetime import date, datetime, timedelta
9
9
  from typing import Any, Literal
10
10
 
11
11
  import pydantic_core
12
+ from typing_extensions import assert_never
12
13
 
13
14
  from .. import _utils
14
15
  from ..messages import (
15
- Message,
16
- ModelAnyResponse,
17
- ModelStructuredResponse,
18
- ModelTextResponse,
19
- RetryPrompt,
20
- ToolCall,
21
- ToolReturn,
16
+ ModelMessage,
17
+ ModelRequest,
18
+ ModelResponse,
19
+ RetryPromptPart,
20
+ TextPart,
21
+ ToolCallPart,
22
+ ToolReturnPart,
22
23
  )
23
24
  from ..result import Cost
25
+ from ..settings import ModelSettings
26
+ from ..tools import ToolDefinition
24
27
  from . import (
25
- AbstractToolDefinition,
26
28
  AgentModel,
27
29
  EitherStreamedResponse,
28
30
  Model,
@@ -55,25 +57,38 @@ class TestModel(Model):
55
57
  """If set, these args will be passed to the result tool."""
56
58
  seed: int = 0
57
59
  """Seed for generating random data."""
58
- # these fields are set when the model is called by the agent
59
- agent_model_tools: Mapping[str, AbstractToolDefinition] | None = field(default=None, init=False)
60
+ agent_model_function_tools: list[ToolDefinition] | None = field(default=None, init=False)
61
+ """Definition of function tools passed to the model.
62
+
63
+ This is set when the model is called, so will reflect the function tools from the last step of the last run.
64
+ """
60
65
  agent_model_allow_text_result: bool | None = field(default=None, init=False)
61
- agent_model_result_tools: list[AbstractToolDefinition] | None = field(default=None, init=False)
66
+ """Whether plain text responses from the model are allowed.
67
+
68
+ This is set when the model is called, so will reflect the value from the last step of the last run.
69
+ """
70
+ agent_model_result_tools: list[ToolDefinition] | None = field(default=None, init=False)
71
+ """Definition of result tools passed to the model.
72
+
73
+ This is set when the model is called, so will reflect the result tools from the last step of the last run.
74
+ """
62
75
 
63
76
  async def agent_model(
64
77
  self,
65
- function_tools: Mapping[str, AbstractToolDefinition],
78
+ *,
79
+ function_tools: list[ToolDefinition],
66
80
  allow_text_result: bool,
67
- result_tools: Sequence[AbstractToolDefinition] | None,
81
+ result_tools: list[ToolDefinition],
68
82
  ) -> AgentModel:
69
- self.agent_model_tools = function_tools
83
+ self.agent_model_function_tools = function_tools
70
84
  self.agent_model_allow_text_result = allow_text_result
71
- self.agent_model_result_tools = list(result_tools) if result_tools is not None else None
85
+ self.agent_model_result_tools = result_tools
72
86
 
73
87
  if self.call_tools == 'all':
74
- tool_calls = [(r.name, r) for r in function_tools.values()]
88
+ tool_calls = [(r.name, r) for r in function_tools]
75
89
  else:
76
- tools_to_call = (function_tools[name] for name in self.call_tools)
90
+ function_tools_lookup = {t.name: t for t in function_tools}
91
+ tools_to_call = (function_tools_lookup[name] for name in self.call_tools)
77
92
  tool_calls = [(r.name, r) for r in tools_to_call]
78
93
 
79
94
  if self.custom_result_text is not None:
@@ -90,11 +105,12 @@ class TestModel(Model):
90
105
  result = _utils.Either(right=self.custom_result_args)
91
106
  elif allow_text_result:
92
107
  result = _utils.Either(left=None)
93
- elif result_tools is not None:
108
+ elif result_tools:
94
109
  result = _utils.Either(right=None)
95
110
  else:
96
111
  result = _utils.Either(left=None)
97
- return TestAgentModel(tool_calls, result, self.agent_model_result_tools, self.seed)
112
+
113
+ return TestAgentModel(tool_calls, result, result_tools, self.seed)
98
114
 
99
115
  def name(self) -> str:
100
116
  return 'test-model'
@@ -107,73 +123,89 @@ class TestAgentModel(AgentModel):
107
123
  # NOTE: Avoid test discovery by pytest.
108
124
  __test__ = False
109
125
 
110
- tool_calls: list[tuple[str, AbstractToolDefinition]]
126
+ tool_calls: list[tuple[str, ToolDefinition]]
111
127
  # left means the text is plain text; right means it's a function call
112
128
  result: _utils.Either[str | None, Any | None]
113
- result_tools: list[AbstractToolDefinition] | None
129
+ result_tools: list[ToolDefinition]
114
130
  seed: int
115
- step: int = 0
116
- last_message_count: int = 0
117
131
 
118
- async def request(self, messages: list[Message]) -> tuple[ModelAnyResponse, Cost]:
119
- return self._request(messages), Cost()
132
+ async def request(
133
+ self, messages: list[ModelMessage], model_settings: ModelSettings | None
134
+ ) -> tuple[ModelResponse, Cost]:
135
+ return self._request(messages, model_settings), Cost()
120
136
 
121
137
  @asynccontextmanager
122
- async def request_stream(self, messages: list[Message]) -> AsyncIterator[EitherStreamedResponse]:
123
- msg = self._request(messages)
138
+ async def request_stream(
139
+ self, messages: list[ModelMessage], model_settings: ModelSettings | None
140
+ ) -> AsyncIterator[EitherStreamedResponse]:
141
+ msg = self._request(messages, model_settings)
124
142
  cost = Cost()
125
- if isinstance(msg, ModelTextResponse):
126
- yield TestStreamTextResponse(msg.content, cost)
143
+
144
+ # TODO: Rework this once we make StreamTextResponse more general
145
+ texts: list[str] = []
146
+ tool_calls: list[ToolCallPart] = []
147
+ for item in msg.parts:
148
+ if isinstance(item, TextPart):
149
+ texts.append(item.content)
150
+ elif isinstance(item, ToolCallPart):
151
+ tool_calls.append(item)
152
+ else:
153
+ assert_never(item)
154
+
155
+ if texts:
156
+ yield TestStreamTextResponse('\n\n'.join(texts), cost)
127
157
  else:
128
158
  yield TestStreamStructuredResponse(msg, cost)
129
159
 
130
- def gen_tool_args(self, tool_def: AbstractToolDefinition) -> Any:
131
- return _JsonSchemaTestData(tool_def.json_schema, self.seed).generate()
132
-
133
- def _request(self, messages: list[Message]) -> ModelAnyResponse:
134
- if self.step == 0 and self.tool_calls:
135
- calls = [ToolCall.from_dict(name, self.gen_tool_args(args)) for name, args in self.tool_calls]
136
- self.step += 1
137
- self.last_message_count = len(messages)
138
- return ModelStructuredResponse(calls=calls)
139
-
140
- new_messages = messages[self.last_message_count :]
141
- self.last_message_count = len(messages)
142
- new_retry_names = {m.tool_name for m in new_messages if isinstance(m, RetryPrompt)}
143
- if new_retry_names:
144
- calls = [
145
- ToolCall.from_dict(name, self.gen_tool_args(args))
146
- for name, args in self.tool_calls
147
- if name in new_retry_names
148
- ]
149
- self.step += 1
150
- return ModelStructuredResponse(calls=calls)
151
- else:
152
- if response_text := self.result.left:
153
- self.step += 1
154
- if response_text.value is None:
155
- # build up details of tool responses
156
- output: dict[str, Any] = {}
157
- for message in messages:
158
- if isinstance(message, ToolReturn):
159
- output[message.tool_name] = message.content
160
- if output:
161
- return ModelTextResponse(content=pydantic_core.to_json(output).decode())
162
- else:
163
- return ModelTextResponse(content='success (no tool calls)')
160
+ def gen_tool_args(self, tool_def: ToolDefinition) -> Any:
161
+ return _JsonSchemaTestData(tool_def.parameters_json_schema, self.seed).generate()
162
+
163
+ def _request(self, messages: list[ModelMessage], model_settings: ModelSettings | None) -> ModelResponse:
164
+ # if there are tools, the first thing we want to do is call all of them
165
+ if self.tool_calls and not any(isinstance(m, ModelResponse) for m in messages):
166
+ return ModelResponse(
167
+ parts=[ToolCallPart.from_dict(name, self.gen_tool_args(args)) for name, args in self.tool_calls]
168
+ )
169
+
170
+ if messages:
171
+ last_message = messages[-1]
172
+ assert isinstance(last_message, ModelRequest), 'Expected last message to be a `ModelRequest`.'
173
+
174
+ # check if there are any retry prompts, if so retry them
175
+ new_retry_names = {p.tool_name for p in last_message.parts if isinstance(p, RetryPromptPart)}
176
+ if new_retry_names:
177
+ return ModelResponse(
178
+ parts=[
179
+ ToolCallPart.from_dict(name, self.gen_tool_args(args))
180
+ for name, args in self.tool_calls
181
+ if name in new_retry_names
182
+ ]
183
+ )
184
+
185
+ if response_text := self.result.left:
186
+ if response_text.value is None:
187
+ # build up details of tool responses
188
+ output: dict[str, Any] = {}
189
+ for message in messages:
190
+ if isinstance(message, ModelRequest):
191
+ for part in message.parts:
192
+ if isinstance(part, ToolReturnPart):
193
+ output[part.tool_name] = part.content
194
+ if output:
195
+ return ModelResponse.from_text(pydantic_core.to_json(output).decode())
164
196
  else:
165
- return ModelTextResponse(content=response_text.value)
197
+ return ModelResponse.from_text('success (no tool calls)')
166
198
  else:
167
- assert self.result_tools is not None, 'No result tools provided'
168
- custom_result_args = self.result.right
169
- result_tool = self.result_tools[self.seed % len(self.result_tools)]
170
- if custom_result_args is not None:
171
- self.step += 1
172
- return ModelStructuredResponse(calls=[ToolCall.from_dict(result_tool.name, custom_result_args)])
173
- else:
174
- response_args = self.gen_tool_args(result_tool)
175
- self.step += 1
176
- return ModelStructuredResponse(calls=[ToolCall.from_dict(result_tool.name, response_args)])
199
+ return ModelResponse.from_text(response_text.value)
200
+ else:
201
+ assert self.result_tools, 'No result tools provided'
202
+ custom_result_args = self.result.right
203
+ result_tool = self.result_tools[self.seed % len(self.result_tools)]
204
+ if custom_result_args is not None:
205
+ return ModelResponse(parts=[ToolCallPart.from_dict(result_tool.name, custom_result_args)])
206
+ else:
207
+ response_args = self.gen_tool_args(result_tool)
208
+ return ModelResponse(parts=[ToolCallPart.from_dict(result_tool.name, response_args)])
177
209
 
178
210
 
179
211
  @dataclass
@@ -213,7 +245,7 @@ class TestStreamTextResponse(StreamTextResponse):
213
245
  class TestStreamStructuredResponse(StreamStructuredResponse):
214
246
  """A structured response that streams test data."""
215
247
 
216
- _structured_response: ModelStructuredResponse
248
+ _structured_response: ModelResponse
217
249
  _cost: Cost
218
250
  _iter: Iterator[None] = field(default_factory=lambda: iter([None]))
219
251
  _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
@@ -221,7 +253,7 @@ class TestStreamStructuredResponse(StreamStructuredResponse):
221
253
  async def __anext__(self) -> None:
222
254
  return _utils.sync_anext(self._iter)
223
255
 
224
- def get(self, *, final: bool = False) -> ModelStructuredResponse:
256
+ def get(self, *, final: bool = False) -> ModelResponse:
225
257
  return self._structured_response
226
258
 
227
259
  def cost(self) -> Cost:
@@ -1,6 +1,5 @@
1
1
  from __future__ import annotations as _annotations
2
2
 
3
- from collections.abc import Mapping, Sequence
4
3
  from dataclasses import dataclass, field
5
4
  from datetime import datetime, timedelta
6
5
  from pathlib import Path
@@ -10,7 +9,8 @@ from httpx import AsyncClient as AsyncHTTPClient
10
9
 
11
10
  from .._utils import run_in_executor
12
11
  from ..exceptions import UserError
13
- from . import AbstractToolDefinition, Model, cached_async_http_client
12
+ from ..tools import ToolDefinition
13
+ from . import Model, cached_async_http_client
14
14
  from .gemini import GeminiAgentModel, GeminiModelName
15
15
 
16
16
  try:
@@ -18,11 +18,11 @@ try:
18
18
  from google.auth.credentials import Credentials as BaseCredentials
19
19
  from google.auth.transport.requests import Request
20
20
  from google.oauth2.service_account import Credentials as ServiceAccountCredentials
21
- except ImportError as e:
21
+ except ImportError as _import_error:
22
22
  raise ImportError(
23
23
  'Please install `google-auth` to use the VertexAI model, '
24
- "you can use the `vertexai` optional group — `pip install 'pydantic-ai[vertexai]'`"
25
- ) from e
24
+ "you can use the `vertexai` optional group — `pip install 'pydantic-ai-slim[vertexai]'`"
25
+ ) from _import_error
26
26
 
27
27
  VERTEX_AI_URL_TEMPLATE = (
28
28
  'https://{region}-aiplatform.googleapis.com/v1'
@@ -109,11 +109,12 @@ class VertexAIModel(Model):
109
109
 
110
110
  async def agent_model(
111
111
  self,
112
- function_tools: Mapping[str, AbstractToolDefinition],
112
+ *,
113
+ function_tools: list[ToolDefinition],
113
114
  allow_text_result: bool,
114
- result_tools: Sequence[AbstractToolDefinition] | None,
115
+ result_tools: list[ToolDefinition],
115
116
  ) -> GeminiAgentModel:
116
- url, auth = await self._ainit()
117
+ url, auth = await self.ainit()
117
118
  return GeminiAgentModel(
118
119
  http_client=self.http_client,
119
120
  model_name=self.model_name,
@@ -124,7 +125,11 @@ class VertexAIModel(Model):
124
125
  result_tools=result_tools,
125
126
  )
126
127
 
127
- async def _ainit(self) -> tuple[str, BearerTokenAuth]:
128
+ async def ainit(self) -> tuple[str, BearerTokenAuth]:
129
+ """Initialize the model, setting the URL and auth.
130
+
131
+ This will raise an error if authentication fails.
132
+ """
128
133
  if self.url is not None and self.auth is not None:
129
134
  return self.url, self.auth
130
135
 
pydantic_ai/result.py CHANGED
@@ -1,14 +1,14 @@
1
1
  from __future__ import annotations as _annotations
2
2
 
3
3
  from abc import ABC, abstractmethod
4
- from collections.abc import AsyncIterator, Callable
4
+ from collections.abc import AsyncIterator, Awaitable, Callable
5
5
  from dataclasses import dataclass, field
6
6
  from datetime import datetime
7
7
  from typing import Generic, TypeVar, cast
8
8
 
9
9
  import logfire_api
10
10
 
11
- from . import _result, _utils, exceptions, messages, models
11
+ from . import _result, _utils, exceptions, messages as _messages, models
12
12
  from .tools import AgentDeps
13
13
 
14
14
  __all__ = (
@@ -71,19 +71,19 @@ class _BaseRunResult(ABC, Generic[ResultData]):
71
71
  You should not import or use this type directly, instead use its subclasses `RunResult` and `StreamedRunResult`.
72
72
  """
73
73
 
74
- _all_messages: list[messages.Message]
74
+ _all_messages: list[_messages.ModelMessage]
75
75
  _new_message_index: int
76
76
 
77
- def all_messages(self) -> list[messages.Message]:
78
- """Return the history of messages."""
77
+ def all_messages(self) -> list[_messages.ModelMessage]:
78
+ """Return the history of _messages."""
79
79
  # this is a method to be consistent with the other methods
80
80
  return self._all_messages
81
81
 
82
82
  def all_messages_json(self) -> bytes:
83
- """Return all messages from [`all_messages`][..all_messages] as JSON bytes."""
84
- return messages.MessagesTypeAdapter.dump_json(self.all_messages())
83
+ """Return all messages from [`all_messages`][pydantic_ai.result._BaseRunResult.all_messages] as JSON bytes."""
84
+ return _messages.ModelMessagesTypeAdapter.dump_json(self.all_messages())
85
85
 
86
- def new_messages(self) -> list[messages.Message]:
86
+ def new_messages(self) -> list[_messages.ModelMessage]:
87
87
  """Return new messages associated with this run.
88
88
 
89
89
  System prompts and any messages from older runs are excluded.
@@ -91,8 +91,8 @@ class _BaseRunResult(ABC, Generic[ResultData]):
91
91
  return self.all_messages()[self._new_message_index :]
92
92
 
93
93
  def new_messages_json(self) -> bytes:
94
- """Return new messages from [`new_messages`][..new_messages] as JSON bytes."""
95
- return messages.MessagesTypeAdapter.dump_json(self.new_messages())
94
+ """Return new messages from [`new_messages`][pydantic_ai.result._BaseRunResult.new_messages] as JSON bytes."""
95
+ return _messages.ModelMessagesTypeAdapter.dump_json(self.new_messages())
96
96
 
97
97
  @abstractmethod
98
98
  def cost(self) -> Cost:
@@ -122,7 +122,8 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
122
122
  _result_schema: _result.ResultSchema[ResultData] | None
123
123
  _deps: AgentDeps
124
124
  _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]]
125
- _on_complete: Callable[[list[messages.Message]], None]
125
+ _result_tool_name: str | None
126
+ _on_complete: Callable[[], Awaitable[None]]
126
127
  is_complete: bool = field(default=False, init=False)
127
128
  """Whether the stream has all been received.
128
129
 
@@ -205,11 +206,11 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
205
206
  combined = await self._validate_text_result(''.join(chunks))
206
207
  yield combined
207
208
  lf_span.set_attribute('combined_text', combined)
208
- self._marked_completed(text=combined)
209
+ await self._marked_completed(_messages.ModelResponse.from_text(combined))
209
210
 
210
211
  async def stream_structured(
211
212
  self, *, debounce_by: float | None = 0.1
212
- ) -> AsyncIterator[tuple[messages.ModelStructuredResponse, bool]]:
213
+ ) -> AsyncIterator[tuple[_messages.ModelResponse, bool]]:
213
214
  """Stream the response as an async iterable of Structured LLM Messages.
214
215
 
215
216
  !!! note
@@ -230,17 +231,21 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
230
231
  else:
231
232
  # we should already have a message at this point, yield that first if it has any content
232
233
  msg = self._stream_response.get()
233
- if any(call.has_content() for call in msg.calls):
234
- yield msg, False
234
+ for item in msg.parts:
235
+ if isinstance(item, _messages.ToolCallPart) and item.has_content():
236
+ yield msg, False
237
+ break
235
238
  async with _utils.group_by_temporal(self._stream_response, debounce_by) as group_iter:
236
239
  async for _ in group_iter:
237
240
  msg = self._stream_response.get()
238
- if any(call.has_content() for call in msg.calls):
239
- yield msg, False
241
+ for item in msg.parts:
242
+ if isinstance(item, _messages.ToolCallPart) and item.has_content():
243
+ yield msg, False
244
+ break
240
245
  msg = self._stream_response.get(final=True)
241
246
  yield msg, True
242
247
  lf_span.set_attribute('structured_response', msg)
243
- self._marked_completed(structured_message=msg)
248
+ await self._marked_completed(msg)
244
249
 
245
250
  async def get_data(self) -> ResultData:
246
251
  """Stream the whole response, validate and return it."""
@@ -249,12 +254,12 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
249
254
  if isinstance(self._stream_response, models.StreamTextResponse):
250
255
  text = ''.join(self._stream_response.get(final=True))
251
256
  text = await self._validate_text_result(text)
252
- self._marked_completed(text=text)
257
+ await self._marked_completed(_messages.ModelResponse.from_text(text))
253
258
  return cast(ResultData, text)
254
259
  else:
255
- structured_message = self._stream_response.get(final=True)
256
- self._marked_completed(structured_message=structured_message)
257
- return await self.validate_structured_result(structured_message)
260
+ message = self._stream_response.get(final=True)
261
+ await self._marked_completed(message)
262
+ return await self.validate_structured_result(message)
258
263
 
259
264
  @property
260
265
  def is_structured(self) -> bool:
@@ -274,11 +279,12 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
274
279
  return self._stream_response.timestamp()
275
280
 
276
281
  async def validate_structured_result(
277
- self, message: messages.ModelStructuredResponse, *, allow_partial: bool = False
282
+ self, message: _messages.ModelResponse, *, allow_partial: bool = False
278
283
  ) -> ResultData:
279
284
  """Validate a structured result message."""
280
285
  assert self._result_schema is not None, 'Expected _result_schema to not be None'
281
- match = self._result_schema.find_tool(message)
286
+ assert self._result_tool_name is not None, 'Expected _result_tool_name to not be None'
287
+ match = self._result_schema.find_named_tool(message.parts, self._result_tool_name)
282
288
  if match is None:
283
289
  raise exceptions.UnexpectedModelBehavior(
284
290
  f'Invalid message, unable to find tool: {self._result_schema.tool_names()}'
@@ -288,7 +294,7 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
288
294
  result_data = result_tool.validate(call, allow_partial=allow_partial, wrap_validation_errors=False)
289
295
 
290
296
  for validator in self._result_validators:
291
- result_data = await validator.validate(result_data, self._deps, 0, call)
297
+ result_data = await validator.validate(result_data, self._deps, 0, call, self._all_messages)
292
298
  return result_data
293
299
 
294
300
  async def _validate_text_result(self, text: str) -> str:
@@ -298,19 +304,11 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
298
304
  self._deps,
299
305
  0,
300
306
  None,
307
+ self._all_messages,
301
308
  )
302
309
  return text
303
310
 
304
- def _marked_completed(
305
- self, *, text: str | None = None, structured_message: messages.ModelStructuredResponse | None = None
306
- ) -> None:
311
+ async def _marked_completed(self, message: _messages.ModelResponse) -> None:
307
312
  self.is_complete = True
308
- if text is not None:
309
- assert structured_message is None, 'Either text or structured_message should provided, not both'
310
- self._all_messages.append(
311
- messages.ModelTextResponse(content=text, timestamp=self._stream_response.timestamp())
312
- )
313
- else:
314
- assert structured_message is not None, 'Either text or structured_message should provided, not both'
315
- self._all_messages.append(structured_message)
316
- self._on_complete(self._all_messages)
313
+ self._all_messages.append(message)
314
+ await self._on_complete()
@@ -0,0 +1,72 @@
1
+ from __future__ import annotations
2
+
3
+ from httpx import Timeout
4
+ from typing_extensions import TypedDict
5
+
6
+
7
+ class ModelSettings(TypedDict, total=False):
8
+ """Settings to configure an LLM.
9
+
10
+ Here we include only settings which apply to multiple models / model providers.
11
+ """
12
+
13
+ max_tokens: int
14
+ """The maximum number of tokens to generate before stopping.
15
+
16
+ Supported by:
17
+ * Gemini
18
+ * Anthropic
19
+ * OpenAI
20
+ * Groq
21
+ """
22
+
23
+ temperature: float
24
+ """Amount of randomness injected into the response.
25
+
26
+ Use `temperature` closer to `0.0` for analytical / multiple choice, and closer to a model's
27
+ maximum `temperature` for creative and generative tasks.
28
+
29
+ Note that even with `temperature` of `0.0`, the results will not be fully deterministic.
30
+
31
+ Supported by:
32
+ * Gemini
33
+ * Anthropic
34
+ * OpenAI
35
+ * Groq
36
+ """
37
+
38
+ top_p: float
39
+ """An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass.
40
+
41
+ So 0.1 means only the tokens comprising the top 10% probability mass are considered.
42
+
43
+ You should either alter `temperature` or `top_p`, but not both.
44
+
45
+ Supported by:
46
+ * Gemini
47
+ * Anthropic
48
+ * OpenAI
49
+ * Groq
50
+ """
51
+
52
+ timeout: float | Timeout
53
+ """Override the client-level default timeout for a request, in seconds.
54
+
55
+ Supported by:
56
+ * Gemini
57
+ * Anthropic
58
+ * OpenAI
59
+ * Groq
60
+ """
61
+
62
+
63
+ def merge_model_settings(base: ModelSettings | None, overrides: ModelSettings | None) -> ModelSettings | None:
64
+ """Merge two sets of model settings, preferring the overrides.
65
+
66
+ A common use case is: merge_model_settings(<agent settings>, <run settings>)
67
+ """
68
+ # Note: we may want merge recursively if/when we add non-primitive values
69
+ if base and overrides:
70
+ return base | overrides
71
+ else:
72
+ return base or overrides