inspect-ai 0.3.99__py3-none-any.whl → 0.3.100__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_display/core/config.py +11 -5
- inspect_ai/_display/core/panel.py +66 -2
- inspect_ai/_display/core/textual.py +5 -2
- inspect_ai/_display/plain/display.py +1 -0
- inspect_ai/_display/rich/display.py +2 -2
- inspect_ai/_display/textual/widgets/transcript.py +37 -9
- inspect_ai/_eval/score.py +2 -4
- inspect_ai/_eval/task/run.py +59 -81
- inspect_ai/_util/content.py +11 -6
- inspect_ai/_util/interrupt.py +2 -2
- inspect_ai/_util/text.py +7 -0
- inspect_ai/_util/working.py +8 -37
- inspect_ai/_view/__init__.py +0 -0
- inspect_ai/_view/schema.py +2 -1
- inspect_ai/_view/www/CLAUDE.md +15 -0
- inspect_ai/_view/www/dist/assets/index.css +263 -159
- inspect_ai/_view/www/dist/assets/index.js +22153 -19093
- inspect_ai/_view/www/log-schema.json +77 -3
- inspect_ai/_view/www/package.json +5 -1
- inspect_ai/_view/www/src/@types/log.d.ts +9 -0
- inspect_ai/_view/www/src/app/App.tsx +1 -15
- inspect_ai/_view/www/src/app/appearance/icons.ts +4 -1
- inspect_ai/_view/www/src/app/content/MetaDataGrid.tsx +24 -6
- inspect_ai/_view/www/src/app/content/MetadataGrid.module.css +0 -5
- inspect_ai/_view/www/src/app/content/RenderedContent.tsx +220 -205
- inspect_ai/_view/www/src/app/log-view/LogViewContainer.tsx +2 -1
- inspect_ai/_view/www/src/app/log-view/tabs/SamplesTab.tsx +5 -0
- inspect_ai/_view/www/src/app/routing/url.ts +84 -4
- inspect_ai/_view/www/src/app/samples/InlineSampleDisplay.module.css +0 -5
- inspect_ai/_view/www/src/app/samples/SampleDialog.module.css +1 -1
- inspect_ai/_view/www/src/app/samples/SampleDisplay.module.css +7 -0
- inspect_ai/_view/www/src/app/samples/SampleDisplay.tsx +24 -17
- inspect_ai/_view/www/src/app/samples/SampleSummaryView.module.css +1 -2
- inspect_ai/_view/www/src/app/samples/chat/ChatMessage.tsx +8 -6
- inspect_ai/_view/www/src/app/samples/chat/ChatMessageRow.tsx +0 -4
- inspect_ai/_view/www/src/app/samples/chat/ChatViewVirtualList.tsx +3 -2
- inspect_ai/_view/www/src/app/samples/chat/MessageContent.tsx +2 -0
- inspect_ai/_view/www/src/app/samples/chat/MessageContents.tsx +2 -0
- inspect_ai/_view/www/src/app/samples/chat/messages.ts +1 -0
- inspect_ai/_view/www/src/app/samples/chat/tools/ToolCallView.tsx +1 -0
- inspect_ai/_view/www/src/app/samples/list/SampleRow.tsx +1 -1
- inspect_ai/_view/www/src/app/samples/transcript/ErrorEventView.tsx +1 -2
- inspect_ai/_view/www/src/app/samples/transcript/InfoEventView.tsx +1 -1
- inspect_ai/_view/www/src/app/samples/transcript/InputEventView.tsx +1 -2
- inspect_ai/_view/www/src/app/samples/transcript/ModelEventView.module.css +1 -1
- inspect_ai/_view/www/src/app/samples/transcript/ModelEventView.tsx +1 -1
- inspect_ai/_view/www/src/app/samples/transcript/SampleInitEventView.tsx +1 -1
- inspect_ai/_view/www/src/app/samples/transcript/SampleLimitEventView.tsx +3 -2
- inspect_ai/_view/www/src/app/samples/transcript/SandboxEventView.tsx +4 -5
- inspect_ai/_view/www/src/app/samples/transcript/ScoreEventView.tsx +1 -1
- inspect_ai/_view/www/src/app/samples/transcript/SpanEventView.tsx +1 -2
- inspect_ai/_view/www/src/app/samples/transcript/StepEventView.tsx +1 -3
- inspect_ai/_view/www/src/app/samples/transcript/SubtaskEventView.tsx +1 -2
- inspect_ai/_view/www/src/app/samples/transcript/ToolEventView.tsx +3 -4
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptPanel.module.css +42 -0
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptPanel.tsx +77 -0
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualList.tsx +27 -71
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualListComponent.module.css +13 -3
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualListComponent.tsx +27 -2
- inspect_ai/_view/www/src/app/samples/transcript/event/EventPanel.module.css +1 -0
- inspect_ai/_view/www/src/app/samples/transcript/event/EventPanel.tsx +21 -22
- inspect_ai/_view/www/src/app/samples/transcript/outline/OutlineRow.module.css +45 -0
- inspect_ai/_view/www/src/app/samples/transcript/outline/OutlineRow.tsx +223 -0
- inspect_ai/_view/www/src/app/samples/transcript/outline/TranscriptOutline.module.css +10 -0
- inspect_ai/_view/www/src/app/samples/transcript/outline/TranscriptOutline.tsx +258 -0
- inspect_ai/_view/www/src/app/samples/transcript/outline/tree-visitors.ts +187 -0
- inspect_ai/_view/www/src/app/samples/transcript/state/StateEventRenderers.tsx +8 -1
- inspect_ai/_view/www/src/app/samples/transcript/state/StateEventView.tsx +3 -4
- inspect_ai/_view/www/src/app/samples/transcript/transform/hooks.ts +78 -0
- inspect_ai/_view/www/src/app/samples/transcript/transform/treeify.ts +340 -135
- inspect_ai/_view/www/src/app/samples/transcript/transform/utils.ts +3 -0
- inspect_ai/_view/www/src/app/samples/transcript/types.ts +2 -0
- inspect_ai/_view/www/src/app/types.ts +5 -1
- inspect_ai/_view/www/src/client/api/api-browser.ts +2 -2
- inspect_ai/_view/www/src/components/LiveVirtualList.tsx +6 -1
- inspect_ai/_view/www/src/components/MarkdownDiv.tsx +1 -1
- inspect_ai/_view/www/src/components/PopOver.tsx +422 -0
- inspect_ai/_view/www/src/components/PulsingDots.module.css +9 -9
- inspect_ai/_view/www/src/components/PulsingDots.tsx +4 -1
- inspect_ai/_view/www/src/components/StickyScroll.tsx +183 -0
- inspect_ai/_view/www/src/components/TabSet.tsx +4 -0
- inspect_ai/_view/www/src/state/hooks.ts +52 -2
- inspect_ai/_view/www/src/state/logSlice.ts +4 -3
- inspect_ai/_view/www/src/state/samplePolling.ts +8 -0
- inspect_ai/_view/www/src/state/sampleSlice.ts +53 -9
- inspect_ai/_view/www/src/state/scrolling.ts +152 -0
- inspect_ai/_view/www/src/utils/attachments.ts +7 -0
- inspect_ai/_view/www/src/utils/python.ts +18 -0
- inspect_ai/_view/www/yarn.lock +269 -6
- inspect_ai/agent/_react.py +12 -7
- inspect_ai/agent/_run.py +2 -3
- inspect_ai/analysis/beta/_dataframe/samples/table.py +19 -18
- inspect_ai/log/_log.py +1 -1
- inspect_ai/log/_recorders/file.py +2 -9
- inspect_ai/log/_transcript.py +1 -1
- inspect_ai/model/_call_tools.py +6 -2
- inspect_ai/model/_openai.py +1 -1
- inspect_ai/model/_openai_responses.py +78 -39
- inspect_ai/model/_openai_web_search.py +31 -0
- inspect_ai/model/_providers/azureai.py +72 -3
- inspect_ai/model/_providers/openai.py +2 -1
- inspect_ai/scorer/_metric.py +1 -2
- inspect_ai/solver/_task_state.py +2 -2
- inspect_ai/tool/_tool.py +6 -2
- inspect_ai/tool/_tool_def.py +27 -4
- inspect_ai/tool/_tool_info.py +2 -0
- inspect_ai/tool/_tools/_web_search/_google.py +15 -4
- inspect_ai/tool/_tools/_web_search/_tavily.py +35 -12
- inspect_ai/tool/_tools/_web_search/_web_search.py +214 -45
- inspect_ai/util/__init__.py +4 -0
- inspect_ai/util/_json.py +3 -0
- inspect_ai/util/_limit.py +230 -20
- inspect_ai/util/_sandbox/docker/compose.py +20 -11
- inspect_ai/util/_span.py +1 -1
- {inspect_ai-0.3.99.dist-info → inspect_ai-0.3.100.dist-info}/METADATA +3 -3
- {inspect_ai-0.3.99.dist-info → inspect_ai-0.3.100.dist-info}/RECORD +120 -106
- {inspect_ai-0.3.99.dist-info → inspect_ai-0.3.100.dist-info}/WHEEL +1 -1
- {inspect_ai-0.3.99.dist-info → inspect_ai-0.3.100.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.99.dist-info → inspect_ai-0.3.100.dist-info}/licenses/LICENSE +0 -0
- {inspect_ai-0.3.99.dist-info → inspect_ai-0.3.100.dist-info}/top_level.txt +0 -0
inspect_ai/tool/_tools/_web_search/_web_search.py
CHANGED
@@ -1,68 +1,123 @@
-from typing import
+from typing import (
+    Any,
+    Awaitable,
+    Callable,
+    Literal,
+    TypeAlias,
+    TypedDict,
+    get_args,
+)
+
+from typing_extensions import Unpack
 
 from inspect_ai._util.deprecation import deprecation_warning
+from inspect_ai.tool._tool_def import ToolDef
 
 from ..._tool import Tool, ToolResult, tool
-from ._google import
-from ._tavily import tavily_search_provider
+from ._google import GoogleOptions, google_search_provider
+from ._tavily import TavilyOptions, tavily_search_provider
+
+Provider: TypeAlias = Literal["openai", "tavily", "google"]  # , "gemini", "anthropic"
+valid_providers = set(get_args(Provider))
+
+
+# It would have been nice if the values below were TypedDicts. The problem is
+# that if the caller creates a literal dict variable (rather than passing the
+# dict inline), the type checker will erase the type of the literal to something
+# that doesn't conform the the required TypedDict when passed. This is lame, but
+# we'll do runtime validation instead.
+#
+# If the caller uses this dict form and uses a value of `None`, it means that
+# they want to use that provider and to use the default options.
+class Providers(TypedDict, total=False):
+    google: dict[str, Any] | None
+    tavily: dict[str, Any] | None
+    openai: dict[str, Any] | None
+
+
+class WebSearchDeprecatedArgs(TypedDict, total=False):
+    provider: Literal["tavily", "google"] | None
+    num_results: int | None
+    max_provider_calls: int | None
+    max_connections: int | None
+    model: str | None
 
 
 @tool
 def web_search(
-
-
-    max_provider_calls: int = 3,
-    max_connections: int = 10,
-    model: str | None = None,
+    providers: Provider | Providers | list[Provider | Providers] | None = None,
+    **deprecated: Unpack[WebSearchDeprecatedArgs],
 ) -> Tool:
     """Web search tool.
 
-
-
-    `use_tools(web_search(provider="tavily"))`))
+    Web searches are executed using a provider. Providers are split
+    into two categories:
 
-
-
-
-
-
+    - Internal providers: "openai" - these use the model's built-in search
+      capability and do not require separate API keys. These work only for
+      their respective model provider (e.g. the "openai" search provider
+      works only for `openai/*` models).
+
+    - External providers: "tavily" and "google". These are external services
+      that work with any model and require separate accounts and API keys.
+
+    Internal providers will be prioritized if running on the corresponding model
+    (e.g., "openai" provider will be used when running on `openai` models). If an
+    internal provider is specified but the evaluation is run with a different
+    model, a fallback external provider must also be specified.
 
     See further documentation at <https://inspect.aisi.org.uk/tools-standard.html#sec-web-search>.
 
     Args:
-
-
-
-
-
-
-
-
-        provider
-
-
-
-
+        providers: Configuration for the search providers to use. Currently supported
+            providers are "openai","tavily", and "google", The `providers` parameter
+            supports several formats based on either a `str` specifying a provider or
+            a `dict` whose keys are the provider names and whose values are the
+            provider-specific options. A single value or a list of these can be passed.
+            This arg is optional just for backwards compatibility. New code should
+            always provide this argument.
+
+            Single provider:
+            ```
+            web_search("tavily")
+            web_search({"tavily": {"max_results": 5}})  # Tavily-specific options
+            ```
+
+            Multiple providers:
+            ```
+            # "openai" used for OpenAI models, "tavily" as fallback
+            web_search(["openai", "tavily"])
+
+            # The None value means to use the provider with default options
+            web_search({"openai": None, "tavily": {"max_results": 5}}
+            ```
+
+            Mixed format:
+            ```
+            web_search(["openai", {"tavily": {"max_results": 5}}])
+            ```
+
+            When specified in the `dict` format, the `None` value for a provider means
+            to use the provider with default options.
+
+            Provider-specific options:
+            - openai: Supports OpenAI's web search parameters.
+              See https://platform.openai.com/docs/guides/tools-web-search?api-mode=responses
+
+            - tavily: Supports options like `max_results`, `search_depth`, etc.
+              See https://docs.tavily.com/documentation/api-reference/endpoint/search
+
+            - google: Supports options like `num_results`, `max_provider_calls`,
+              `max_connections`, and `model`
+
+        **deprecated: Deprecated arguments.
 
     Returns:
         A tool that can be registered for use by models to search the web.
     """
-
-    if maybe_get_google_api_keys():
-        deprecation_warning(
-            "The `google` `web_search` provider was inferred based on the presence of environment variables. Please specify the provider explicitly to avoid this warning."
-        )
-        provider = "google"
-    else:
-        raise ValueError(
-            "Omitting `provider` is no longer supported. Please specify the `web_search` provider explicitly to avoid this error."
-        )
+    normalized_providers = _normalize_config(providers, **deprecated)
 
-    search_provider =
-        google_search_provider(num_results, max_provider_calls, max_connections, model)
-        if provider == "google"
-        else tavily_search_provider(num_results, max_connections)
-    )
+    search_provider: Callable[[str], Awaitable[str | None]] | None = None
 
     async def execute(query: str) -> ToolResult:
         """
@@ -71,6 +126,9 @@ def web_search(
         Args:
             query (str): Search query.
         """
+        nonlocal search_provider
+        if not search_provider:
+            search_provider = _create_external_provider(normalized_providers)
         search_result = await search_provider(query)
 
         return (
@@ -82,4 +140,115 @@ def web_search(
             else ("I'm sorry, I couldn't find any relevant information on the web.")
         )
 
-    return
+    return ToolDef(
+        execute, name="web_search", options=dict(normalized_providers)
+    ).as_tool()
+
+
+def _normalize_config(
+    providers: Provider | Providers | list[Provider | Providers] | None,
+    **deprecated: Unpack[WebSearchDeprecatedArgs],
+) -> Providers:
+    """
+    Deal with breaking changes in the web_search parameter list.
+
+    This function adapts (hopefully) all of the old variants of how the tool
+    factory may have been called converts to the new config format.
+    """
+    # Cases to handle:
+    # 1. Both deprecated_provider and providers are set
+    #    ValueError
+    # 2. Neither deprecated_provider nor providers is set
+    #    act as if they passed provider="google"
+    # 3. Only providers is set
+    #    if any of the other deprecated parameters is set, then ValueError
+    #    else Happy path
+    # 4. Only deprecated_provider is set
+    #    convert to new config format - including processing old other params
+
+    deprecated_provider = deprecated.get("provider", None)
+    # Case 1.
+    if deprecated_provider and providers:
+        raise ValueError("`provider` is deprecated. Please only specify `providers`.")
+
+    # Case 2.
+    if providers is None and deprecated_provider is None:
+        deprecated_provider = "google"
+
+    num_results = deprecated.get("num_results", None)
+    max_provider_calls = deprecated.get("max_provider_calls", None)
+    max_connections = deprecated.get("max_connections", None)
+    model = deprecated.get("model", None)
+
+    # Getting here means that we have either a providers or a deprecated_provider
+    if deprecated_provider:
+        return _get_config_via_back_compat(
+            deprecated_provider,
+            num_results=num_results,
+            max_provider_calls=max_provider_calls,
+            max_connections=max_connections,
+            model=model,
+        )
+
+    assert providers, "providers should not be None here"
+    normalized: Providers = {}
+    for entry in providers if isinstance(providers, list) else [providers]:
+        if isinstance(entry, str):
+            if entry not in valid_providers:
+                raise ValueError(f"Invalid provider: '{entry}'")
+            normalized[entry] = None  # type: ignore
+        else:
+            for key, value in entry.items():
+                if key not in valid_providers:
+                    raise ValueError(f"Invalid provider: '{key}'")
+                normalized[key] = value  # type: ignore
+    return normalized
+
+
+def _get_config_via_back_compat(
+    provider: Literal["tavily", "google"],
+    num_results: int | None,
+    max_provider_calls: int | None,
+    max_connections: int | None,
+    model: str | None,
+) -> Providers:
+    if (
+        num_results is None
+        and max_provider_calls is None
+        and max_connections is None
+        and model is None
+    ):
+        return {"google": None} if provider == "google" else {"tavily": None}
+
+    # If we get here, we have at least one old school parameter
+    deprecation_warning(
+        "The `num_results`, `max_provider_calls`, `max_connections`, and `model` parameters are deprecated. Please use the `config` parameter instead."
+    )
+
+    if provider == "google":
+        return {
+            "google": GoogleOptions(
+                num_results=num_results,
+                max_provider_calls=max_provider_calls,
+                max_connections=max_connections,
+                model=model,
+            ).model_dump(exclude_none=True)
+        }
+    else:
+        return {
+            "tavily": TavilyOptions(
+                max_results=num_results, max_connections=max_connections
+            ).model_dump(exclude_none=True)
+        }
+
+
+def _create_external_provider(
+    providers: Providers,
+) -> Callable[[str], Awaitable[str | None]]:
+    if "tavily" in providers:
+        return tavily_search_provider(providers.get("tavily", None))
+
+    if "google" in providers:
+        return google_search_provider(providers.get("google", None))
+
+    raise ValueError("No valid provider found.")
inspect_ai/util/__init__.py
CHANGED
@@ -6,7 +6,9 @@ from inspect_ai.util._limit import (
     LimitScope,
     apply_limits,
     message_limit,
+    time_limit,
     token_limit,
+    working_limit,
 )
 
 from ._collect import collect
@@ -81,6 +83,8 @@ __all__ = [
     "subtask",
     "throttle",
     "token_limit",
+    "time_limit",
+    "working_limit",
     "trace_action",
     "trace_message",
     "RegistryType",
inspect_ai/util/_json.py
CHANGED
@@ -3,6 +3,7 @@ import typing
 from copy import deepcopy
 from dataclasses import is_dataclass
 from datetime import date, datetime, time
+from enum import EnumMeta
 from typing import (
     Any,
     Dict,
@@ -101,6 +102,8 @@ def json_schema(t: Type[Any]) -> JSONSchema:
         or (isinstance(t, type) and issubclass(t, BaseModel))
     ):
         return cls_json_schema(t)
+    elif isinstance(t, EnumMeta):
+        return JSONSchema(enum=[item.value for item in t])
     elif t is type(None):
         return JSONSchema(type="null")
     else:
inspect_ai/util/_limit.py
CHANGED
@@ -7,6 +7,7 @@ from contextvars import ContextVar
 from types import TracebackType
 from typing import TYPE_CHECKING, Generic, Iterator, Literal, TypeVar
 
+import anyio
 from typing_extensions import Self
 
 from inspect_ai._util.logger import warn_once
@@ -33,22 +34,23 @@ class LimitExceededError(Exception):
         value: Value compared to.
         limit: Limit applied.
         message (str | None): Optional. Human readable message.
-        source (Limit | None): Optional. The `Limit` instance which was responsible for
-            raising this error.
+        source (Limit | None): Optional. The `Limit` instance which was responsible for raising this error.
     """
 
     def __init__(
         self,
         type: Literal["message", "time", "working", "token", "operator", "custom"],
         *,
-        value:
-        limit:
+        value: float,
+        limit: float,
         message: str | None = None,
         source: Limit | None = None,
     ) -> None:
         self.type = type
         self.value = value
+        self.value_str = self._format_float_or_int(value)
         self.limit = limit
+        self.limit_str = self._format_float_or_int(limit)
         self.message = f"Exceeded {type} limit: {limit:,}"
         self.source = source
         super().__init__(message)
@@ -60,6 +62,12 @@ class LimitExceededError(Exception):
         )
         return self
 
+    def _format_float_or_int(self, value: float | int) -> str:
+        if isinstance(value, int):
+            return f"{value:,}"
+        else:
+            return f"{value:,.2f}"
+
 
 class Limit(abc.ABC):
     """Base class for all limit context managers."""
@@ -80,6 +88,12 @@ class Limit(abc.ABC):
     ) -> None:
         pass
 
+    @property
+    @abc.abstractmethod
+    def usage(self) -> float:
+        """The current usage of the resource being limited."""
+        pass
+
     def _check_reuse(self) -> None:
         if self._entered:
             raise RuntimeError(
@@ -112,18 +126,20 @@ def apply_limits(
             False, all `LimitExceededError` exceptions will be allowed to propagate.
     """
     limit_scope = LimitScope()
-    with ExitStack()
-
-
-
+    # Try scope is outside the `with ExitStack()` so that we can catch any errors raised
+    # when exiting it (which will be where time_limit() would raise LimitExceededError).
+    try:
+        with ExitStack() as stack:
+            for limit in limits:
+                stack.enter_context(limit)
             yield limit_scope
-
-
-
-
-
-
-
+    except LimitExceededError as e:
+        # If it was not one of the limits we applied.
+        if e.source is None or e.source not in limits:
+            raise
+        limit_scope.limit_error = e
+        if not catch_errors:
+            raise
 
 
 class LimitScope:
@@ -140,8 +156,6 @@ def token_limit(limit: int | None) -> _TokenLimit:
     """Limits the total number of tokens which can be used.
 
     The counter starts when the context manager is opened and ends when it is closed.
-    The context manager can be opened multiple times, even in different execution
-    contexts.
 
     These limits can be stacked.
 
@@ -186,8 +200,7 @@ def message_limit(limit: int | None) -> _MessageLimit:
     """Limits the number of messages in a conversation.
 
     The total number of messages in the conversation are compared to the limit (not just
-    "new" messages).
-    execution contexts.
+    "new" messages).
 
     These limits can be stacked.
 
@@ -220,6 +233,62 @@ def check_message_limit(count: int, raise_for_equal: bool) -> None:
     node.check(count, raise_for_equal)
 
 
+def time_limit(limit: float | None) -> _TimeLimit:
+    """Limits the wall clock time which can elapse.
+
+    The timer starts when the context manager is opened and stops when it is closed.
+
+    These limits can be stacked.
+
+    When a limit is exceeded, the code block is cancelled and a `LimitExceededError` is
+    raised.
+
+    Uses anyio's cancellation scopes meaning that the operations within the context
+    manager block are cancelled if the limit is exceeded. The `LimitExceededError` is
+    therefore raised at the level that the `time_limit()` context manager was opened,
+    not at the level of the operation which caused the limit to be exceeded (e.g. a call
+    to `generate()`). Ensure you handle `LimitExceededError` at the level of opening the context manager.
+
+    Args:
+        limit: The maximum number of seconds that can pass while the context manager is
+            open. A value of None means unlimited time.
+    """
+    return _TimeLimit(limit)
+
+
+def working_limit(limit: float | None) -> _WorkingLimit:
+    """Limits the working time which can elapse.
+
+    Working time is the wall clock time minus any waiting time e.g. waiting before
+    retrying in response to rate limits or waiting on a semaphore.
+
+    The timer starts when the context manager is opened and stops when it is closed.
+
+    These limits can be stacked.
+
+    When a limit is exceeded, a `LimitExceededError` is raised.
+
+    Args:
+        limit: The maximum number of seconds of working that can pass while the context
+            manager is open. A value of None means unlimited time.
+    """
+    return _WorkingLimit(limit)
+
+
+def record_waiting_time(waiting_time: float) -> None:
+    node = working_limit_tree.get()
+    if node is None:
+        return
+    node.record_waiting_time(waiting_time)
+
+
+def check_working_limit() -> None:
+    node = working_limit_tree.get()
+    if node is None:
+        return
+    node.check()
+
+
 class _Tree(Generic[TNode]):
     """A tree data structure of limit nodes.
 
@@ -253,6 +322,7 @@ token_limit_tree: _Tree[_TokenLimit] = _Tree("token_limit_tree")
 # Store the message limit leaf node so that we know which limit to check in
 # check_message_limit().
 message_limit_tree: _Tree[_MessageLimit] = _Tree("message_limit_tree")
+working_limit_tree: _Tree[_WorkingLimit] = _Tree("working_limit_tree")
 
 
 class _Node:
@@ -296,6 +366,10 @@ class _TokenLimit(Limit, _Node):
     ) -> None:
         self._pop_and_check_identity(token_limit_tree)
 
+    @property
+    def usage(self) -> float:
+        return self._usage.total_tokens
+
     @property
     def limit(self) -> int | None:
         """Get the configured token limit value."""
@@ -312,7 +386,7 @@ class _TokenLimit(Limit, _Node):
         self._limit = value
 
     def record(self, usage: ModelUsage) -> None:
-        """Record model usage for this node and its
+        """Record model usage for this node and its ancestor nodes."""
         if self.parent is not None:
             self.parent.record(usage)
         self._usage += usage
@@ -369,6 +443,13 @@ class _MessageLimit(Limit, _Node):
     ) -> None:
         self._pop_and_check_identity(message_limit_tree)
 
+    @property
+    def usage(self) -> float:
+        raise NotImplementedError(
+            "Retrieving the message count from a limit is not supported. Please query "
+            "the messages property on the task or agent state instead."
+        )
+
     @property
     def limit(self) -> int | None:
         """Get the configured message limit value."""
@@ -414,3 +495,132 @@ class _MessageLimit(Limit, _Node):
             raise ValueError(
                 f"Message limit value must be a non-negative integer or None: {value}"
             )
+
+
+class _TimeLimit(Limit):
+    def __init__(self, limit: float | None) -> None:
+        super().__init__()
+        _validate_time_limit("Time", limit)
+        self._limit = limit
+        self._start_time: float | None = None
+        self._end_time: float | None = None
+
+    def __enter__(self) -> Limit:
+        super()._check_reuse()
+        # Unlike the other limits, this one is not stored in a tree. Anyio handles all
+        # of the state.
+        self._cancel_scope = anyio.move_on_after(self._limit)
+        self._cancel_scope.__enter__()
+        self._start_time = anyio.current_time()
+        return self
+
+    def __exit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc_val: BaseException | None,
+        exc_tb: TracebackType | None,
+    ) -> None:
+        from inspect_ai.log._transcript import SampleLimitEvent, transcript
+
+        self._cancel_scope.__exit__(exc_type, exc_val, exc_tb)
+        self._end_time = anyio.current_time()
+        if self._cancel_scope.cancel_called and self._limit is not None:
+            message = f"Time limit exceeded. limit: {self._limit} seconds"
+            assert self._start_time is not None
+            # Note we've measured the elapsed time independently of anyio's cancel scope
+            # so this is an approximation.
+            time_elapsed = self._end_time - self._start_time
+            transcript()._event(
+                SampleLimitEvent(type="time", message=message, limit=self._limit)
+            )
+            raise LimitExceededError(
+                "time",
+                value=time_elapsed,
+                limit=self._limit,
+                message=message,
+                source=self,
+            ) from exc_val
+
+    @property
+    def usage(self) -> float:
+        if self._start_time is None:
+            return 0.0
+        if self._end_time is None:
+            return anyio.current_time() - self._start_time
+        return self._end_time - self._start_time
+
+
+class _WorkingLimit(Limit, _Node):
+    def __init__(self, limit: float | None) -> None:
+        super().__init__()
+        _validate_time_limit("Working time", limit)
+        self._limit = limit
+        self.parent: _WorkingLimit | None = None
+        self._start_time: float | None = None
+        self._end_time: float | None = None
+
+    def __enter__(self) -> Limit:
+        super()._check_reuse()
+        self._start_time = anyio.current_time()
+        self._waiting_time = 0.0
+        working_limit_tree.push(self)
+        return self
+
+    def __exit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc_val: BaseException | None,
+        exc_tb: TracebackType | None,
+    ) -> None:
+        self._end_time = anyio.current_time()
+        self._pop_and_check_identity(working_limit_tree)
+
+    @property
+    def usage(self) -> float:
+        if self._start_time is None:
+            return 0.0
+        if self._end_time is None:
+            return anyio.current_time() - self._start_time - self._waiting_time
+        return self._end_time - self._start_time - self._waiting_time
+
+    def record_waiting_time(self, waiting_time: float) -> None:
+        """Record waiting time for this node and its ancestor nodes."""
+        if self.parent is not None:
+            self.parent.record_waiting_time(waiting_time)
+        self._waiting_time += waiting_time
+
+    def check(self) -> None:
+        """Check if this working time limit or any ancestor limits have been exceeded.
+
+        The checks occur from root to leaf. This is so that if multiple limits are
+        simultaneously exceeded, the outermost (closest to root) one raises the error,
+        preventing certain sub-agent architectures from ending up in an infinite loop.
+        """
+        if self.parent is not None:
+            self.parent.check()
+        self._check_self()
+
+    def _check_self(self) -> None:
+        from inspect_ai.log._transcript import SampleLimitEvent, transcript
+
+        if self._limit is None:
+            return
+        if self.usage > self._limit:
+            message = f"Working time limit exceeded. limit: {self._limit} seconds"
+            transcript()._event(
+                SampleLimitEvent(type="working", message=message, limit=self._limit)
+            )
+            raise LimitExceededError(
+                "working",
+                value=self.usage,
+                limit=self._limit,
+                message=message,
+                source=self,
+            )
+
+
+def _validate_time_limit(name: str, value: float | None) -> None:
+    if value is not None and value < 0:
+        raise ValueError(
+            f"{name} limit value must be a non-negative float or None: {value}"
+        )