inspect-ai 0.3.99__py3-none-any.whl → 0.3.100__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (120) hide show
  1. inspect_ai/_display/core/config.py +11 -5
  2. inspect_ai/_display/core/panel.py +66 -2
  3. inspect_ai/_display/core/textual.py +5 -2
  4. inspect_ai/_display/plain/display.py +1 -0
  5. inspect_ai/_display/rich/display.py +2 -2
  6. inspect_ai/_display/textual/widgets/transcript.py +37 -9
  7. inspect_ai/_eval/score.py +2 -4
  8. inspect_ai/_eval/task/run.py +59 -81
  9. inspect_ai/_util/content.py +11 -6
  10. inspect_ai/_util/interrupt.py +2 -2
  11. inspect_ai/_util/text.py +7 -0
  12. inspect_ai/_util/working.py +8 -37
  13. inspect_ai/_view/__init__.py +0 -0
  14. inspect_ai/_view/schema.py +2 -1
  15. inspect_ai/_view/www/CLAUDE.md +15 -0
  16. inspect_ai/_view/www/dist/assets/index.css +263 -159
  17. inspect_ai/_view/www/dist/assets/index.js +22153 -19093
  18. inspect_ai/_view/www/log-schema.json +77 -3
  19. inspect_ai/_view/www/package.json +5 -1
  20. inspect_ai/_view/www/src/@types/log.d.ts +9 -0
  21. inspect_ai/_view/www/src/app/App.tsx +1 -15
  22. inspect_ai/_view/www/src/app/appearance/icons.ts +4 -1
  23. inspect_ai/_view/www/src/app/content/MetaDataGrid.tsx +24 -6
  24. inspect_ai/_view/www/src/app/content/MetadataGrid.module.css +0 -5
  25. inspect_ai/_view/www/src/app/content/RenderedContent.tsx +220 -205
  26. inspect_ai/_view/www/src/app/log-view/LogViewContainer.tsx +2 -1
  27. inspect_ai/_view/www/src/app/log-view/tabs/SamplesTab.tsx +5 -0
  28. inspect_ai/_view/www/src/app/routing/url.ts +84 -4
  29. inspect_ai/_view/www/src/app/samples/InlineSampleDisplay.module.css +0 -5
  30. inspect_ai/_view/www/src/app/samples/SampleDialog.module.css +1 -1
  31. inspect_ai/_view/www/src/app/samples/SampleDisplay.module.css +7 -0
  32. inspect_ai/_view/www/src/app/samples/SampleDisplay.tsx +24 -17
  33. inspect_ai/_view/www/src/app/samples/SampleSummaryView.module.css +1 -2
  34. inspect_ai/_view/www/src/app/samples/chat/ChatMessage.tsx +8 -6
  35. inspect_ai/_view/www/src/app/samples/chat/ChatMessageRow.tsx +0 -4
  36. inspect_ai/_view/www/src/app/samples/chat/ChatViewVirtualList.tsx +3 -2
  37. inspect_ai/_view/www/src/app/samples/chat/MessageContent.tsx +2 -0
  38. inspect_ai/_view/www/src/app/samples/chat/MessageContents.tsx +2 -0
  39. inspect_ai/_view/www/src/app/samples/chat/messages.ts +1 -0
  40. inspect_ai/_view/www/src/app/samples/chat/tools/ToolCallView.tsx +1 -0
  41. inspect_ai/_view/www/src/app/samples/list/SampleRow.tsx +1 -1
  42. inspect_ai/_view/www/src/app/samples/transcript/ErrorEventView.tsx +1 -2
  43. inspect_ai/_view/www/src/app/samples/transcript/InfoEventView.tsx +1 -1
  44. inspect_ai/_view/www/src/app/samples/transcript/InputEventView.tsx +1 -2
  45. inspect_ai/_view/www/src/app/samples/transcript/ModelEventView.module.css +1 -1
  46. inspect_ai/_view/www/src/app/samples/transcript/ModelEventView.tsx +1 -1
  47. inspect_ai/_view/www/src/app/samples/transcript/SampleInitEventView.tsx +1 -1
  48. inspect_ai/_view/www/src/app/samples/transcript/SampleLimitEventView.tsx +3 -2
  49. inspect_ai/_view/www/src/app/samples/transcript/SandboxEventView.tsx +4 -5
  50. inspect_ai/_view/www/src/app/samples/transcript/ScoreEventView.tsx +1 -1
  51. inspect_ai/_view/www/src/app/samples/transcript/SpanEventView.tsx +1 -2
  52. inspect_ai/_view/www/src/app/samples/transcript/StepEventView.tsx +1 -3
  53. inspect_ai/_view/www/src/app/samples/transcript/SubtaskEventView.tsx +1 -2
  54. inspect_ai/_view/www/src/app/samples/transcript/ToolEventView.tsx +3 -4
  55. inspect_ai/_view/www/src/app/samples/transcript/TranscriptPanel.module.css +42 -0
  56. inspect_ai/_view/www/src/app/samples/transcript/TranscriptPanel.tsx +77 -0
  57. inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualList.tsx +27 -71
  58. inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualListComponent.module.css +13 -3
  59. inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualListComponent.tsx +27 -2
  60. inspect_ai/_view/www/src/app/samples/transcript/event/EventPanel.module.css +1 -0
  61. inspect_ai/_view/www/src/app/samples/transcript/event/EventPanel.tsx +21 -22
  62. inspect_ai/_view/www/src/app/samples/transcript/outline/OutlineRow.module.css +45 -0
  63. inspect_ai/_view/www/src/app/samples/transcript/outline/OutlineRow.tsx +223 -0
  64. inspect_ai/_view/www/src/app/samples/transcript/outline/TranscriptOutline.module.css +10 -0
  65. inspect_ai/_view/www/src/app/samples/transcript/outline/TranscriptOutline.tsx +258 -0
  66. inspect_ai/_view/www/src/app/samples/transcript/outline/tree-visitors.ts +187 -0
  67. inspect_ai/_view/www/src/app/samples/transcript/state/StateEventRenderers.tsx +8 -1
  68. inspect_ai/_view/www/src/app/samples/transcript/state/StateEventView.tsx +3 -4
  69. inspect_ai/_view/www/src/app/samples/transcript/transform/hooks.ts +78 -0
  70. inspect_ai/_view/www/src/app/samples/transcript/transform/treeify.ts +340 -135
  71. inspect_ai/_view/www/src/app/samples/transcript/transform/utils.ts +3 -0
  72. inspect_ai/_view/www/src/app/samples/transcript/types.ts +2 -0
  73. inspect_ai/_view/www/src/app/types.ts +5 -1
  74. inspect_ai/_view/www/src/client/api/api-browser.ts +2 -2
  75. inspect_ai/_view/www/src/components/LiveVirtualList.tsx +6 -1
  76. inspect_ai/_view/www/src/components/MarkdownDiv.tsx +1 -1
  77. inspect_ai/_view/www/src/components/PopOver.tsx +422 -0
  78. inspect_ai/_view/www/src/components/PulsingDots.module.css +9 -9
  79. inspect_ai/_view/www/src/components/PulsingDots.tsx +4 -1
  80. inspect_ai/_view/www/src/components/StickyScroll.tsx +183 -0
  81. inspect_ai/_view/www/src/components/TabSet.tsx +4 -0
  82. inspect_ai/_view/www/src/state/hooks.ts +52 -2
  83. inspect_ai/_view/www/src/state/logSlice.ts +4 -3
  84. inspect_ai/_view/www/src/state/samplePolling.ts +8 -0
  85. inspect_ai/_view/www/src/state/sampleSlice.ts +53 -9
  86. inspect_ai/_view/www/src/state/scrolling.ts +152 -0
  87. inspect_ai/_view/www/src/utils/attachments.ts +7 -0
  88. inspect_ai/_view/www/src/utils/python.ts +18 -0
  89. inspect_ai/_view/www/yarn.lock +269 -6
  90. inspect_ai/agent/_react.py +12 -7
  91. inspect_ai/agent/_run.py +2 -3
  92. inspect_ai/analysis/beta/_dataframe/samples/table.py +19 -18
  93. inspect_ai/log/_log.py +1 -1
  94. inspect_ai/log/_recorders/file.py +2 -9
  95. inspect_ai/log/_transcript.py +1 -1
  96. inspect_ai/model/_call_tools.py +6 -2
  97. inspect_ai/model/_openai.py +1 -1
  98. inspect_ai/model/_openai_responses.py +78 -39
  99. inspect_ai/model/_openai_web_search.py +31 -0
  100. inspect_ai/model/_providers/azureai.py +72 -3
  101. inspect_ai/model/_providers/openai.py +2 -1
  102. inspect_ai/scorer/_metric.py +1 -2
  103. inspect_ai/solver/_task_state.py +2 -2
  104. inspect_ai/tool/_tool.py +6 -2
  105. inspect_ai/tool/_tool_def.py +27 -4
  106. inspect_ai/tool/_tool_info.py +2 -0
  107. inspect_ai/tool/_tools/_web_search/_google.py +15 -4
  108. inspect_ai/tool/_tools/_web_search/_tavily.py +35 -12
  109. inspect_ai/tool/_tools/_web_search/_web_search.py +214 -45
  110. inspect_ai/util/__init__.py +4 -0
  111. inspect_ai/util/_json.py +3 -0
  112. inspect_ai/util/_limit.py +230 -20
  113. inspect_ai/util/_sandbox/docker/compose.py +20 -11
  114. inspect_ai/util/_span.py +1 -1
  115. {inspect_ai-0.3.99.dist-info → inspect_ai-0.3.100.dist-info}/METADATA +3 -3
  116. {inspect_ai-0.3.99.dist-info → inspect_ai-0.3.100.dist-info}/RECORD +120 -106
  117. {inspect_ai-0.3.99.dist-info → inspect_ai-0.3.100.dist-info}/WHEEL +1 -1
  118. {inspect_ai-0.3.99.dist-info → inspect_ai-0.3.100.dist-info}/entry_points.txt +0 -0
  119. {inspect_ai-0.3.99.dist-info → inspect_ai-0.3.100.dist-info}/licenses/LICENSE +0 -0
  120. {inspect_ai-0.3.99.dist-info → inspect_ai-0.3.100.dist-info}/top_level.txt +0 -0
@@ -1,68 +1,123 @@
1
- from typing import Literal
1
+ from typing import (
2
+ Any,
3
+ Awaitable,
4
+ Callable,
5
+ Literal,
6
+ TypeAlias,
7
+ TypedDict,
8
+ get_args,
9
+ )
10
+
11
+ from typing_extensions import Unpack
2
12
 
3
13
  from inspect_ai._util.deprecation import deprecation_warning
14
+ from inspect_ai.tool._tool_def import ToolDef
4
15
 
5
16
  from ..._tool import Tool, ToolResult, tool
6
- from ._google import google_search_provider, maybe_get_google_api_keys
7
- from ._tavily import tavily_search_provider
17
+ from ._google import GoogleOptions, google_search_provider
18
+ from ._tavily import TavilyOptions, tavily_search_provider
19
+
20
+ Provider: TypeAlias = Literal["openai", "tavily", "google"] # , "gemini", "anthropic"
21
+ valid_providers = set(get_args(Provider))
22
+
23
+
24
+ # It would have been nice if the values below were TypedDicts. The problem is
25
+ # that if the caller creates a literal dict variable (rather than passing the
26
+ # dict inline), the type checker will erase the type of the literal to something
27
+ # that doesn't conform to the required TypedDict when passed. This is lame, but
28
+ # we'll do runtime validation instead.
29
+ #
30
+ # If the caller uses this dict form and uses a value of `None`, it means that
31
+ # they want to use that provider and to use the default options.
32
+ class Providers(TypedDict, total=False):
33
+ google: dict[str, Any] | None
34
+ tavily: dict[str, Any] | None
35
+ openai: dict[str, Any] | None
36
+
37
+
38
+ class WebSearchDeprecatedArgs(TypedDict, total=False):
39
+ provider: Literal["tavily", "google"] | None
40
+ num_results: int | None
41
+ max_provider_calls: int | None
42
+ max_connections: int | None
43
+ model: str | None
8
44
 
9
45
 
10
46
  @tool
11
47
  def web_search(
12
- provider: Literal["tavily", "google"] | None = None,
13
- num_results: int = 3,
14
- max_provider_calls: int = 3,
15
- max_connections: int = 10,
16
- model: str | None = None,
48
+ providers: Provider | Providers | list[Provider | Providers] | None = None,
49
+ **deprecated: Unpack[WebSearchDeprecatedArgs],
17
50
  ) -> Tool:
18
51
  """Web search tool.
19
52
 
20
- A tool that can be registered for use by models to search the web. Use
21
- the `use_tools()` solver to make the tool available (e.g.
22
- `use_tools(web_search(provider="tavily"))`))
53
+ Web searches are executed using a provider. Providers are split
54
+ into two categories:
23
55
 
24
- A web search is conducted using the specified provider.
25
- - When using Tavily, all logic for relevance and summarization is handled by
26
- the Tavily API.
27
- - When using Google, the results are parsed for relevance using the specified
28
- model, and the top 'num_results' relevant pages are returned.
56
+ - Internal providers: "openai" - these use the model's built-in search
57
+ capability and do not require separate API keys. These work only for
58
+ their respective model provider (e.g. the "openai" search provider
59
+ works only for `openai/*` models).
60
+
61
+ - External providers: "tavily" and "google". These are external services
62
+ that work with any model and require separate accounts and API keys.
63
+
64
+ Internal providers will be prioritized if running on the corresponding model
65
+ (e.g., "openai" provider will be used when running on `openai` models). If an
66
+ internal provider is specified but the evaluation is run with a different
67
+ model, a fallback external provider must also be specified.
29
68
 
30
69
  See further documentation at <https://inspect.aisi.org.uk/tools-standard.html#sec-web-search>.
31
70
 
32
71
  Args:
33
- provider: Search provider to use:
34
- - "tavily": Uses Tavily's Research API.
35
- - "google": Uses Google Custom Search.
36
- Note: The `| None` type is only for backwards compatibility. Passing
37
- `None` is deprecated.
38
- num_results: The number of search result pages used to provide information
39
- back to the model.
40
- max_provider_calls: Maximum number of search calls to make to the search
41
- provider.
42
- max_connections: Maximum number of concurrent connections to API endpoint
43
- of search provider.
44
- model: Model used to parse web pages for relevance - used only by the
45
- `google` provider.
72
+ providers: Configuration for the search providers to use. Currently supported
73
+ providers are "openai", "tavily", and "google". The `providers` parameter
74
+ supports several formats based on either a `str` specifying a provider or
75
+ a `dict` whose keys are the provider names and whose values are the
76
+ provider-specific options. A single value or a list of these can be passed.
77
+ This arg is optional just for backwards compatibility. New code should
78
+ always provide this argument.
79
+
80
+ Single provider:
81
+ ```
82
+ web_search("tavily")
83
+ web_search({"tavily": {"max_results": 5}}) # Tavily-specific options
84
+ ```
85
+
86
+ Multiple providers:
87
+ ```
88
+ # "openai" used for OpenAI models, "tavily" as fallback
89
+ web_search(["openai", "tavily"])
90
+
91
+ # The None value means to use the provider with default options
92
+ web_search({"openai": None, "tavily": {"max_results": 5}})
93
+ ```
94
+
95
+ Mixed format:
96
+ ```
97
+ web_search(["openai", {"tavily": {"max_results": 5}}])
98
+ ```
99
+
100
+ When specified in the `dict` format, the `None` value for a provider means
101
+ to use the provider with default options.
102
+
103
+ Provider-specific options:
104
+ - openai: Supports OpenAI's web search parameters.
105
+ See https://platform.openai.com/docs/guides/tools-web-search?api-mode=responses
106
+
107
+ - tavily: Supports options like `max_results`, `search_depth`, etc.
108
+ See https://docs.tavily.com/documentation/api-reference/endpoint/search
109
+
110
+ - google: Supports options like `num_results`, `max_provider_calls`,
111
+ `max_connections`, and `model`
112
+
113
+ **deprecated: Deprecated arguments.
46
114
 
47
115
  Returns:
48
116
  A tool that can be registered for use by models to search the web.
49
117
  """
50
- if provider is None:
51
- if maybe_get_google_api_keys():
52
- deprecation_warning(
53
- "The `google` `web_search` provider was inferred based on the presence of environment variables. Please specify the provider explicitly to avoid this warning."
54
- )
55
- provider = "google"
56
- else:
57
- raise ValueError(
58
- "Omitting `provider` is no longer supported. Please specify the `web_search` provider explicitly to avoid this error."
59
- )
118
+ normalized_providers = _normalize_config(providers, **deprecated)
60
119
 
61
- search_provider = (
62
- google_search_provider(num_results, max_provider_calls, max_connections, model)
63
- if provider == "google"
64
- else tavily_search_provider(num_results, max_connections)
65
- )
120
+ search_provider: Callable[[str], Awaitable[str | None]] | None = None
66
121
 
67
122
  async def execute(query: str) -> ToolResult:
68
123
  """
@@ -71,6 +126,9 @@ def web_search(
71
126
  Args:
72
127
  query (str): Search query.
73
128
  """
129
+ nonlocal search_provider
130
+ if not search_provider:
131
+ search_provider = _create_external_provider(normalized_providers)
74
132
  search_result = await search_provider(query)
75
133
 
76
134
  return (
@@ -82,4 +140,115 @@ def web_search(
82
140
  else ("I'm sorry, I couldn't find any relevant information on the web.")
83
141
  )
84
142
 
85
- return execute
143
+ return ToolDef(
144
+ execute, name="web_search", options=dict(normalized_providers)
145
+ ).as_tool()
146
+
147
+
148
+ def _normalize_config(
149
+ providers: Provider | Providers | list[Provider | Providers] | None,
150
+ **deprecated: Unpack[WebSearchDeprecatedArgs],
151
+ ) -> Providers:
152
+ """
153
+ Deal with breaking changes in the web_search parameter list.
154
+
155
+ This function adapts (hopefully) all of the old variants of how the tool
156
+ factory may have been called and converts them to the new config format.
157
+ """
158
+ # Cases to handle:
159
+ # 1. Both deprecated_provider and providers are set
160
+ # ValueError
161
+ # 2. Neither deprecated_provider nor providers is set
162
+ # act as if they passed provider="google"
163
+ # 3. Only providers is set
164
+ # if any of the other deprecated parameters is set, then ValueError
165
+ # else Happy path
166
+ # 4. Only deprecated_provider is set
167
+ # convert to new config format - including processing old other params
168
+
169
+ deprecated_provider = deprecated.get("provider", None)
170
+ # Case 1.
171
+ if deprecated_provider and providers:
172
+ raise ValueError("`provider` is deprecated. Please only specify `providers`.")
173
+
174
+ # Case 2.
175
+ if providers is None and deprecated_provider is None:
176
+ deprecated_provider = "google"
177
+
178
+ num_results = deprecated.get("num_results", None)
179
+ max_provider_calls = deprecated.get("max_provider_calls", None)
180
+ max_connections = deprecated.get("max_connections", None)
181
+ model = deprecated.get("model", None)
182
+
183
+ # Getting here means that we have either a providers or a deprecated_provider
184
+ if deprecated_provider:
185
+ return _get_config_via_back_compat(
186
+ deprecated_provider,
187
+ num_results=num_results,
188
+ max_provider_calls=max_provider_calls,
189
+ max_connections=max_connections,
190
+ model=model,
191
+ )
192
+
193
+ assert providers, "providers should not be None here"
194
+ normalized: Providers = {}
195
+ for entry in providers if isinstance(providers, list) else [providers]:
196
+ if isinstance(entry, str):
197
+ if entry not in valid_providers:
198
+ raise ValueError(f"Invalid provider: '{entry}'")
199
+ normalized[entry] = None # type: ignore
200
+ else:
201
+ for key, value in entry.items():
202
+ if key not in valid_providers:
203
+ raise ValueError(f"Invalid provider: '{key}'")
204
+ normalized[key] = value # type: ignore
205
+ return normalized
206
+
207
+
208
+ def _get_config_via_back_compat(
209
+ provider: Literal["tavily", "google"],
210
+ num_results: int | None,
211
+ max_provider_calls: int | None,
212
+ max_connections: int | None,
213
+ model: str | None,
214
+ ) -> Providers:
215
+ if (
216
+ num_results is None
217
+ and max_provider_calls is None
218
+ and max_connections is None
219
+ and model is None
220
+ ):
221
+ return {"google": None} if provider == "google" else {"tavily": None}
222
+
223
+ # If we get here, we have at least one old school parameter
224
+ deprecation_warning(
225
+ "The `num_results`, `max_provider_calls`, `max_connections`, and `model` parameters are deprecated. Please use the `config` parameter instead."
226
+ )
227
+
228
+ if provider == "google":
229
+ return {
230
+ "google": GoogleOptions(
231
+ num_results=num_results,
232
+ max_provider_calls=max_provider_calls,
233
+ max_connections=max_connections,
234
+ model=model,
235
+ ).model_dump(exclude_none=True)
236
+ }
237
+ else:
238
+ return {
239
+ "tavily": TavilyOptions(
240
+ max_results=num_results, max_connections=max_connections
241
+ ).model_dump(exclude_none=True)
242
+ }
243
+
244
+
245
+ def _create_external_provider(
246
+ providers: Providers,
247
+ ) -> Callable[[str], Awaitable[str | None]]:
248
+ if "tavily" in providers:
249
+ return tavily_search_provider(providers.get("tavily", None))
250
+
251
+ if "google" in providers:
252
+ return google_search_provider(providers.get("google", None))
253
+
254
+ raise ValueError("No valid provider found.")
@@ -6,7 +6,9 @@ from inspect_ai.util._limit import (
6
6
  LimitScope,
7
7
  apply_limits,
8
8
  message_limit,
9
+ time_limit,
9
10
  token_limit,
11
+ working_limit,
10
12
  )
11
13
 
12
14
  from ._collect import collect
@@ -81,6 +83,8 @@ __all__ = [
81
83
  "subtask",
82
84
  "throttle",
83
85
  "token_limit",
86
+ "time_limit",
87
+ "working_limit",
84
88
  "trace_action",
85
89
  "trace_message",
86
90
  "RegistryType",
inspect_ai/util/_json.py CHANGED
@@ -3,6 +3,7 @@ import typing
3
3
  from copy import deepcopy
4
4
  from dataclasses import is_dataclass
5
5
  from datetime import date, datetime, time
6
+ from enum import EnumMeta
6
7
  from typing import (
7
8
  Any,
8
9
  Dict,
@@ -101,6 +102,8 @@ def json_schema(t: Type[Any]) -> JSONSchema:
101
102
  or (isinstance(t, type) and issubclass(t, BaseModel))
102
103
  ):
103
104
  return cls_json_schema(t)
105
+ elif isinstance(t, EnumMeta):
106
+ return JSONSchema(enum=[item.value for item in t])
104
107
  elif t is type(None):
105
108
  return JSONSchema(type="null")
106
109
  else:
inspect_ai/util/_limit.py CHANGED
@@ -7,6 +7,7 @@ from contextvars import ContextVar
7
7
  from types import TracebackType
8
8
  from typing import TYPE_CHECKING, Generic, Iterator, Literal, TypeVar
9
9
 
10
+ import anyio
10
11
  from typing_extensions import Self
11
12
 
12
13
  from inspect_ai._util.logger import warn_once
@@ -33,22 +34,23 @@ class LimitExceededError(Exception):
33
34
  value: Value compared to.
34
35
  limit: Limit applied.
35
36
  message (str | None): Optional. Human readable message.
36
- source (Limit | None): Optional. The `Limit` instance which was responsible for
37
- raising this error.
37
+ source (Limit | None): Optional. The `Limit` instance which was responsible for raising this error.
38
38
  """
39
39
 
40
40
  def __init__(
41
41
  self,
42
42
  type: Literal["message", "time", "working", "token", "operator", "custom"],
43
43
  *,
44
- value: int,
45
- limit: int,
44
+ value: float,
45
+ limit: float,
46
46
  message: str | None = None,
47
47
  source: Limit | None = None,
48
48
  ) -> None:
49
49
  self.type = type
50
50
  self.value = value
51
+ self.value_str = self._format_float_or_int(value)
51
52
  self.limit = limit
53
+ self.limit_str = self._format_float_or_int(limit)
52
54
  self.message = f"Exceeded {type} limit: {limit:,}"
53
55
  self.source = source
54
56
  super().__init__(message)
@@ -60,6 +62,12 @@ class LimitExceededError(Exception):
60
62
  )
61
63
  return self
62
64
 
65
+ def _format_float_or_int(self, value: float | int) -> str:
66
+ if isinstance(value, int):
67
+ return f"{value:,}"
68
+ else:
69
+ return f"{value:,.2f}"
70
+
63
71
 
64
72
  class Limit(abc.ABC):
65
73
  """Base class for all limit context managers."""
@@ -80,6 +88,12 @@ class Limit(abc.ABC):
80
88
  ) -> None:
81
89
  pass
82
90
 
91
+ @property
92
+ @abc.abstractmethod
93
+ def usage(self) -> float:
94
+ """The current usage of the resource being limited."""
95
+ pass
96
+
83
97
  def _check_reuse(self) -> None:
84
98
  if self._entered:
85
99
  raise RuntimeError(
@@ -112,18 +126,20 @@ def apply_limits(
112
126
  False, all `LimitExceededError` exceptions will be allowed to propagate.
113
127
  """
114
128
  limit_scope = LimitScope()
115
- with ExitStack() as stack:
116
- for limit in limits:
117
- stack.enter_context(limit)
118
- try:
129
+ # Try scope is outside the `with ExitStack()` so that we can catch any errors raised
130
+ # when exiting it (which will be where time_limit() would raise LimitExceededError).
131
+ try:
132
+ with ExitStack() as stack:
133
+ for limit in limits:
134
+ stack.enter_context(limit)
119
135
  yield limit_scope
120
- except LimitExceededError as e:
121
- # If it was not one of the limits we applied.
122
- if e.source is None or e.source not in limits:
123
- raise
124
- limit_scope.limit_error = e
125
- if not catch_errors:
126
- raise
136
+ except LimitExceededError as e:
137
+ # If it was not one of the limits we applied.
138
+ if e.source is None or e.source not in limits:
139
+ raise
140
+ limit_scope.limit_error = e
141
+ if not catch_errors:
142
+ raise
127
143
 
128
144
 
129
145
  class LimitScope:
@@ -140,8 +156,6 @@ def token_limit(limit: int | None) -> _TokenLimit:
140
156
  """Limits the total number of tokens which can be used.
141
157
 
142
158
  The counter starts when the context manager is opened and ends when it is closed.
143
- The context manager can be opened multiple times, even in different execution
144
- contexts.
145
159
 
146
160
  These limits can be stacked.
147
161
 
@@ -186,8 +200,7 @@ def message_limit(limit: int | None) -> _MessageLimit:
186
200
  """Limits the number of messages in a conversation.
187
201
 
188
202
  The total number of messages in the conversation are compared to the limit (not just
189
- "new" messages). The context manager can be opened multiple times, even in different
190
- execution contexts.
203
+ "new" messages).
191
204
 
192
205
  These limits can be stacked.
193
206
 
@@ -220,6 +233,62 @@ def check_message_limit(count: int, raise_for_equal: bool) -> None:
220
233
  node.check(count, raise_for_equal)
221
234
 
222
235
 
236
+ def time_limit(limit: float | None) -> _TimeLimit:
237
+ """Limits the wall clock time which can elapse.
238
+
239
+ The timer starts when the context manager is opened and stops when it is closed.
240
+
241
+ These limits can be stacked.
242
+
243
+ When a limit is exceeded, the code block is cancelled and a `LimitExceededError` is
244
+ raised.
245
+
246
+ Uses anyio's cancellation scopes meaning that the operations within the context
247
+ manager block are cancelled if the limit is exceeded. The `LimitExceededError` is
248
+ therefore raised at the level that the `time_limit()` context manager was opened,
249
+ not at the level of the operation which caused the limit to be exceeded (e.g. a call
250
+ to `generate()`). Ensure you handle `LimitExceededError` at the level of opening the context manager.
251
+
252
+ Args:
253
+ limit: The maximum number of seconds that can pass while the context manager is
254
+ open. A value of None means unlimited time.
255
+ """
256
+ return _TimeLimit(limit)
257
+
258
+
259
+ def working_limit(limit: float | None) -> _WorkingLimit:
260
+ """Limits the working time which can elapse.
261
+
262
+ Working time is the wall clock time minus any waiting time e.g. waiting before
263
+ retrying in response to rate limits or waiting on a semaphore.
264
+
265
+ The timer starts when the context manager is opened and stops when it is closed.
266
+
267
+ These limits can be stacked.
268
+
269
+ When a limit is exceeded, a `LimitExceededError` is raised.
270
+
271
+ Args:
272
+ limit: The maximum number of seconds of working that can pass while the context
273
+ manager is open. A value of None means unlimited time.
274
+ """
275
+ return _WorkingLimit(limit)
276
+
277
+
278
+ def record_waiting_time(waiting_time: float) -> None:
279
+ node = working_limit_tree.get()
280
+ if node is None:
281
+ return
282
+ node.record_waiting_time(waiting_time)
283
+
284
+
285
+ def check_working_limit() -> None:
286
+ node = working_limit_tree.get()
287
+ if node is None:
288
+ return
289
+ node.check()
290
+
291
+
223
292
  class _Tree(Generic[TNode]):
224
293
  """A tree data structure of limit nodes.
225
294
 
@@ -253,6 +322,7 @@ token_limit_tree: _Tree[_TokenLimit] = _Tree("token_limit_tree")
253
322
  # Store the message limit leaf node so that we know which limit to check in
254
323
  # check_message_limit().
255
324
  message_limit_tree: _Tree[_MessageLimit] = _Tree("message_limit_tree")
325
+ working_limit_tree: _Tree[_WorkingLimit] = _Tree("working_limit_tree")
256
326
 
257
327
 
258
328
  class _Node:
@@ -296,6 +366,10 @@ class _TokenLimit(Limit, _Node):
296
366
  ) -> None:
297
367
  self._pop_and_check_identity(token_limit_tree)
298
368
 
369
+ @property
370
+ def usage(self) -> float:
371
+ return self._usage.total_tokens
372
+
299
373
  @property
300
374
  def limit(self) -> int | None:
301
375
  """Get the configured token limit value."""
@@ -312,7 +386,7 @@ class _TokenLimit(Limit, _Node):
312
386
  self._limit = value
313
387
 
314
388
  def record(self, usage: ModelUsage) -> None:
315
- """Record model usage for this node and its parent nodes."""
389
+ """Record model usage for this node and its ancestor nodes."""
316
390
  if self.parent is not None:
317
391
  self.parent.record(usage)
318
392
  self._usage += usage
@@ -369,6 +443,13 @@ class _MessageLimit(Limit, _Node):
369
443
  ) -> None:
370
444
  self._pop_and_check_identity(message_limit_tree)
371
445
 
446
+ @property
447
+ def usage(self) -> float:
448
+ raise NotImplementedError(
449
+ "Retrieving the message count from a limit is not supported. Please query "
450
+ "the messages property on the task or agent state instead."
451
+ )
452
+
372
453
  @property
373
454
  def limit(self) -> int | None:
374
455
  """Get the configured message limit value."""
@@ -414,3 +495,132 @@ class _MessageLimit(Limit, _Node):
414
495
  raise ValueError(
415
496
  f"Message limit value must be a non-negative integer or None: {value}"
416
497
  )
498
+
499
+
500
+ class _TimeLimit(Limit):
501
+ def __init__(self, limit: float | None) -> None:
502
+ super().__init__()
503
+ _validate_time_limit("Time", limit)
504
+ self._limit = limit
505
+ self._start_time: float | None = None
506
+ self._end_time: float | None = None
507
+
508
+ def __enter__(self) -> Limit:
509
+ super()._check_reuse()
510
+ # Unlike the other limits, this one is not stored in a tree. Anyio handles all
511
+ # of the state.
512
+ self._cancel_scope = anyio.move_on_after(self._limit)
513
+ self._cancel_scope.__enter__()
514
+ self._start_time = anyio.current_time()
515
+ return self
516
+
517
+ def __exit__(
518
+ self,
519
+ exc_type: type[BaseException] | None,
520
+ exc_val: BaseException | None,
521
+ exc_tb: TracebackType | None,
522
+ ) -> None:
523
+ from inspect_ai.log._transcript import SampleLimitEvent, transcript
524
+
525
+ self._cancel_scope.__exit__(exc_type, exc_val, exc_tb)
526
+ self._end_time = anyio.current_time()
527
+ if self._cancel_scope.cancel_called and self._limit is not None:
528
+ message = f"Time limit exceeded. limit: {self._limit} seconds"
529
+ assert self._start_time is not None
530
+ # Note we've measured the elapsed time independently of anyio's cancel scope
531
+ # so this is an approximation.
532
+ time_elapsed = self._end_time - self._start_time
533
+ transcript()._event(
534
+ SampleLimitEvent(type="time", message=message, limit=self._limit)
535
+ )
536
+ raise LimitExceededError(
537
+ "time",
538
+ value=time_elapsed,
539
+ limit=self._limit,
540
+ message=message,
541
+ source=self,
542
+ ) from exc_val
543
+
544
+ @property
545
+ def usage(self) -> float:
546
+ if self._start_time is None:
547
+ return 0.0
548
+ if self._end_time is None:
549
+ return anyio.current_time() - self._start_time
550
+ return self._end_time - self._start_time
551
+
552
+
553
+ class _WorkingLimit(Limit, _Node):
554
+ def __init__(self, limit: float | None) -> None:
555
+ super().__init__()
556
+ _validate_time_limit("Working time", limit)
557
+ self._limit = limit
558
+ self.parent: _WorkingLimit | None = None
559
+ self._start_time: float | None = None
560
+ self._end_time: float | None = None
561
+
562
+ def __enter__(self) -> Limit:
563
+ super()._check_reuse()
564
+ self._start_time = anyio.current_time()
565
+ self._waiting_time = 0.0
566
+ working_limit_tree.push(self)
567
+ return self
568
+
569
+ def __exit__(
570
+ self,
571
+ exc_type: type[BaseException] | None,
572
+ exc_val: BaseException | None,
573
+ exc_tb: TracebackType | None,
574
+ ) -> None:
575
+ self._end_time = anyio.current_time()
576
+ self._pop_and_check_identity(working_limit_tree)
577
+
578
+ @property
579
+ def usage(self) -> float:
580
+ if self._start_time is None:
581
+ return 0.0
582
+ if self._end_time is None:
583
+ return anyio.current_time() - self._start_time - self._waiting_time
584
+ return self._end_time - self._start_time - self._waiting_time
585
+
586
+ def record_waiting_time(self, waiting_time: float) -> None:
587
+ """Record waiting time for this node and its ancestor nodes."""
588
+ if self.parent is not None:
589
+ self.parent.record_waiting_time(waiting_time)
590
+ self._waiting_time += waiting_time
591
+
592
+ def check(self) -> None:
593
+ """Check if this working time limit or any ancestor limits have been exceeded.
594
+
595
+ The checks occur from root to leaf. This is so that if multiple limits are
596
+ simultaneously exceeded, the outermost (closest to root) one raises the error,
597
+ preventing certain sub-agent architectures from ending up in an infinite loop.
598
+ """
599
+ if self.parent is not None:
600
+ self.parent.check()
601
+ self._check_self()
602
+
603
+ def _check_self(self) -> None:
604
+ from inspect_ai.log._transcript import SampleLimitEvent, transcript
605
+
606
+ if self._limit is None:
607
+ return
608
+ if self.usage > self._limit:
609
+ message = f"Working time limit exceeded. limit: {self._limit} seconds"
610
+ transcript()._event(
611
+ SampleLimitEvent(type="working", message=message, limit=self._limit)
612
+ )
613
+ raise LimitExceededError(
614
+ "working",
615
+ value=self.usage,
616
+ limit=self._limit,
617
+ message=message,
618
+ source=self,
619
+ )
620
+
621
+
622
+ def _validate_time_limit(name: str, value: float | None) -> None:
623
+ if value is not None and value < 0:
624
+ raise ValueError(
625
+ f"{name} limit value must be a non-negative float or None: {value}"
626
+ )