inspect-ai 0.3.92__py3-none-any.whl → 0.3.94__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_cli/eval.py +27 -0
- inspect_ai/_display/textual/widgets/samples.py +3 -3
- inspect_ai/_display/textual/widgets/transcript.py +3 -29
- inspect_ai/_eval/eval.py +19 -2
- inspect_ai/_eval/evalset.py +4 -1
- inspect_ai/_eval/run.py +41 -0
- inspect_ai/_eval/task/generate.py +38 -44
- inspect_ai/_eval/task/log.py +26 -28
- inspect_ai/_eval/task/run.py +23 -27
- inspect_ai/_util/answer.py +26 -0
- inspect_ai/_util/constants.py +0 -1
- inspect_ai/_util/local_server.py +398 -0
- inspect_ai/_util/working.py +10 -4
- inspect_ai/_view/www/dist/assets/index.css +173 -159
- inspect_ai/_view/www/dist/assets/index.js +1417 -1142
- inspect_ai/_view/www/log-schema.json +379 -3
- inspect_ai/_view/www/package.json +1 -1
- inspect_ai/_view/www/src/@types/log.d.ts +93 -14
- inspect_ai/_view/www/src/app/content/MetaDataGrid.tsx +2 -2
- inspect_ai/_view/www/src/app/content/MetaDataView.module.css +1 -1
- inspect_ai/_view/www/src/app/content/MetadataGrid.module.css +1 -1
- inspect_ai/_view/www/src/app/content/RenderedContent.tsx +1 -1
- inspect_ai/_view/www/src/app/log-view/LogView.tsx +11 -0
- inspect_ai/_view/www/src/app/log-view/tabs/InfoTab.tsx +2 -9
- inspect_ai/_view/www/src/app/log-view/tabs/ModelsTab.tsx +51 -0
- inspect_ai/_view/www/src/app/log-view/tabs/TaskTab.module.css +6 -0
- inspect_ai/_view/www/src/app/log-view/tabs/TaskTab.tsx +143 -0
- inspect_ai/_view/www/src/app/plan/ModelCard.tsx +1 -2
- inspect_ai/_view/www/src/app/plan/PlanCard.tsx +29 -7
- inspect_ai/_view/www/src/app/plan/PlanDetailView.module.css +1 -1
- inspect_ai/_view/www/src/app/plan/PlanDetailView.tsx +1 -198
- inspect_ai/_view/www/src/app/samples/descriptor/score/NumericScoreDescriptor.tsx +2 -1
- inspect_ai/_view/www/src/app/samples/transcript/SandboxEventView.module.css +2 -1
- inspect_ai/_view/www/src/app/samples/transcript/SpanEventView.tsx +174 -0
- inspect_ai/_view/www/src/app/samples/transcript/ToolEventView.tsx +8 -8
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptView.tsx +12 -2
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualListComponent.module.css +1 -1
- inspect_ai/_view/www/src/app/samples/transcript/event/EventPanel.tsx +0 -3
- inspect_ai/_view/www/src/app/samples/transcript/transform/fixups.ts +87 -25
- inspect_ai/_view/www/src/app/samples/transcript/transform/treeify.ts +229 -17
- inspect_ai/_view/www/src/app/samples/transcript/transform/utils.ts +11 -0
- inspect_ai/_view/www/src/app/samples/transcript/types.ts +5 -1
- inspect_ai/_view/www/src/app/usage/ModelUsagePanel.tsx +3 -2
- inspect_ai/_view/www/src/app/usage/TokenTable.module.css +4 -1
- inspect_ai/_view/www/src/app/usage/TokenTable.tsx +2 -2
- inspect_ai/_view/www/src/app/usage/UsageCard.module.css +8 -3
- inspect_ai/_view/www/src/app/usage/UsageCard.tsx +1 -35
- inspect_ai/_view/www/src/components/Card.css +0 -1
- inspect_ai/_view/www/src/constants.ts +2 -0
- inspect_ai/_view/www/src/utils/numeric.ts +17 -0
- inspect_ai/agent/_agent.py +3 -3
- inspect_ai/agent/_as_solver.py +22 -12
- inspect_ai/agent/_as_tool.py +20 -6
- inspect_ai/agent/_handoff.py +12 -1
- inspect_ai/agent/_react.py +4 -3
- inspect_ai/agent/_run.py +16 -3
- inspect_ai/agent/_types.py +9 -0
- inspect_ai/dataset/_dataset.py +6 -3
- inspect_ai/log/__init__.py +14 -0
- inspect_ai/log/_convert.py +4 -9
- inspect_ai/log/_file.py +56 -0
- inspect_ai/log/_log.py +99 -0
- inspect_ai/log/_recorders/__init__.py +2 -0
- inspect_ai/log/_recorders/buffer/database.py +12 -11
- inspect_ai/log/_recorders/buffer/filestore.py +2 -2
- inspect_ai/log/_recorders/buffer/types.py +2 -2
- inspect_ai/log/_recorders/eval.py +20 -65
- inspect_ai/log/_recorders/file.py +28 -6
- inspect_ai/log/_recorders/recorder.py +7 -0
- inspect_ai/log/_recorders/types.py +1 -23
- inspect_ai/log/_samples.py +14 -25
- inspect_ai/log/_transcript.py +84 -36
- inspect_ai/log/_tree.py +118 -0
- inspect_ai/log/_util.py +52 -0
- inspect_ai/model/__init__.py +5 -1
- inspect_ai/model/_call_tools.py +72 -44
- inspect_ai/model/_generate_config.py +14 -8
- inspect_ai/model/_model.py +66 -88
- inspect_ai/model/_model_output.py +25 -0
- inspect_ai/model/_openai.py +2 -0
- inspect_ai/model/_providers/anthropic.py +13 -23
- inspect_ai/model/_providers/hf.py +27 -1
- inspect_ai/model/_providers/openai_o1.py +8 -2
- inspect_ai/model/_providers/providers.py +18 -4
- inspect_ai/model/_providers/sglang.py +247 -0
- inspect_ai/model/_providers/vllm.py +211 -400
- inspect_ai/scorer/_choice.py +1 -2
- inspect_ai/solver/__init__.py +7 -2
- inspect_ai/solver/_basic_agent.py +3 -10
- inspect_ai/solver/_chain.py +1 -1
- inspect_ai/solver/_fork.py +1 -1
- inspect_ai/solver/_multiple_choice.py +5 -22
- inspect_ai/solver/_plan.py +2 -2
- inspect_ai/solver/_task_state.py +26 -88
- inspect_ai/solver/_transcript.py +6 -7
- inspect_ai/tool/_json_rpc_helpers.py +45 -17
- inspect_ai/tool/_mcp/_mcp.py +8 -5
- inspect_ai/tool/_mcp/_sandbox.py +8 -2
- inspect_ai/tool/_mcp/server.py +3 -1
- inspect_ai/tool/_tool_call.py +4 -1
- inspect_ai/tool/_tool_support_helpers.py +51 -12
- inspect_ai/tool/_tools/_bash_session.py +190 -68
- inspect_ai/tool/_tools/_computer/_computer.py +25 -1
- inspect_ai/tool/_tools/_execute.py +4 -1
- inspect_ai/tool/_tools/_text_editor.py +4 -3
- inspect_ai/tool/_tools/_web_browser/_web_browser.py +10 -3
- inspect_ai/util/__init__.py +16 -0
- inspect_ai/util/_anyio.py +11 -0
- inspect_ai/util/_collect.py +50 -0
- inspect_ai/util/_limit.py +393 -0
- inspect_ai/util/_limited_conversation.py +57 -0
- inspect_ai/util/_span.py +58 -0
- inspect_ai/util/_subtask.py +27 -42
- {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.94.dist-info}/METADATA +1 -1
- {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.94.dist-info}/RECORD +120 -134
- {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.94.dist-info}/WHEEL +1 -1
- inspect_ai/_display/core/group.py +0 -79
- inspect_ai/solver/_limit.py +0 -39
- inspect_ai/tool/_tools/_computer/_resources/Dockerfile +0 -102
- inspect_ai/tool/_tools/_computer/_resources/README.md +0 -30
- inspect_ai/tool/_tools/_computer/_resources/entrypoint/entrypoint.sh +0 -18
- inspect_ai/tool/_tools/_computer/_resources/entrypoint/novnc_startup.sh +0 -20
- inspect_ai/tool/_tools/_computer/_resources/entrypoint/x11vnc_startup.sh +0 -48
- inspect_ai/tool/_tools/_computer/_resources/entrypoint/xfce_startup.sh +0 -13
- inspect_ai/tool/_tools/_computer/_resources/entrypoint/xvfb_startup.sh +0 -48
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/Code/User/globalStorage/state.vscdb +0 -0
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/Code/User/settings.json +0 -9
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/xfce4/xfconf/xfce-perchannel-xml/xfce4-panel.xml +0 -61
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/xfce4/xfconf/xfce-perchannel-xml/xfce4-screensaver.xml +0 -10
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/xfce4/xfconf/xfce-perchannel-xml/xfwm4.xml +0 -91
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/Desktop/Firefox Web Browser.desktop +0 -10
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/Desktop/Terminal.desktop +0 -10
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/Desktop/Visual Studio Code.desktop +0 -10
- inspect_ai/tool/_tools/_computer/_resources/tool/.pylintrc +0 -8
- inspect_ai/tool/_tools/_computer/_resources/tool/.vscode/settings.json +0 -12
- inspect_ai/tool/_tools/_computer/_resources/tool/_args.py +0 -78
- inspect_ai/tool/_tools/_computer/_resources/tool/_constants.py +0 -22
- inspect_ai/tool/_tools/_computer/_resources/tool/_logger.py +0 -22
- inspect_ai/tool/_tools/_computer/_resources/tool/_run.py +0 -42
- inspect_ai/tool/_tools/_computer/_resources/tool/_tool_result.py +0 -33
- inspect_ai/tool/_tools/_computer/_resources/tool/_x11_client.py +0 -341
- inspect_ai/tool/_tools/_computer/_resources/tool/computer_tool.py +0 -141
- inspect_ai/tool/_tools/_computer/_resources/tool/pyproject.toml +0 -65
- inspect_ai/tool/_tools/_computer/_resources/tool/requirements.txt +0 -0
- inspect_ai/tool/_tools/_computer/test_args.py +0 -151
- /inspect_ai/{tool/_tools/_computer/_resources/tool/__init__.py → _view/www/src/app/log-view/tabs/ModelsTab.module.css} +0 -0
- {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.94.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.94.dist-info}/licenses/LICENSE +0 -0
- {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.94.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,50 @@
|
|
1
|
+
import sys
|
2
|
+
from typing import Awaitable, TypeVar, cast
|
3
|
+
|
4
|
+
import anyio
|
5
|
+
|
6
|
+
from ._span import span
|
7
|
+
|
8
|
+
if sys.version_info < (3, 11):
|
9
|
+
from exceptiongroup import ExceptionGroup
|
10
|
+
|
11
|
+
|
12
|
+
T = TypeVar("T")
|
13
|
+
|
14
|
+
|
15
|
+
async def collect(*tasks: Awaitable[T]) -> list[T]:
    """Run and collect the results of one or more async coroutines.

    Similar to [`asyncio.gather()`](https://docs.python.org/3/library/asyncio-task.html#asyncio.gather),
    but also works when [Trio](https://trio.readthedocs.io/en/stable/) is the async backend.

    Each task is automatically wrapped in a `span()` so that its events are
    grouped together in the transcript. Prefer `collect()` over
    `asyncio.gather()` for both Trio compatibility and more legible
    transcript output.

    Args:
        *tasks: Tasks to run

    Returns:
        List of task results.
    """
    outcomes: list[None | T] = [None] * len(tasks)

    async def _run(slot: int, awaitable: Awaitable[T]) -> None:
        # Each task gets its own span so its events group together.
        async with span(f"task-{slot + 1}", type="task"):
            outcomes[slot] = await awaitable

    try:
        async with anyio.create_task_group() as group:
            for slot, awaitable in enumerate(tasks):
                group.start_soon(_run, slot, awaitable)
    except ExceptionGroup as eg:
        # A lone failure is re-raised bare so callers see the familiar
        # exception type rather than a group wrapper.
        if len(eg.exceptions) == 1:
            raise eg.exceptions[0] from None
        raise

    return cast(list[T], outcomes)
|
@@ -0,0 +1,393 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import abc
|
4
|
+
import logging
|
5
|
+
from contextlib import ExitStack, contextmanager
|
6
|
+
from contextvars import ContextVar
|
7
|
+
from types import TracebackType
|
8
|
+
from typing import TYPE_CHECKING, Iterator, Literal
|
9
|
+
|
10
|
+
from inspect_ai._util.logger import warn_once
|
11
|
+
|
12
|
+
if TYPE_CHECKING:
|
13
|
+
# These imports are used as type hints only - prevent circular imports.
|
14
|
+
from inspect_ai.model._model_output import ModelUsage
|
15
|
+
from inspect_ai.solver._task_state import TaskState
|
16
|
+
|
17
|
+
|
18
|
+
logger = logging.getLogger(__name__)
|
19
|
+
|
20
|
+
# Stores the current execution context's leaf _TokenLimitNode.
|
21
|
+
# The resulting data structure is a tree of _TokenLimitNode nodes which each
|
22
|
+
# have a pointer to their parent node. Each additional context manager inserts a new
|
23
|
+
# child node into the tree. The fact that there can be multiple execution contexts is
|
24
|
+
# what makes this a tree rather than a stack.
|
25
|
+
token_limit_leaf_node: ContextVar[_TokenLimitNode | None] = ContextVar(
|
26
|
+
"token_limit_leaf_node", default=None
|
27
|
+
)
|
28
|
+
message_limit_leaf_node: ContextVar[_MessageLimitNode | None] = ContextVar(
|
29
|
+
"message_limit_leaf_node", default=None
|
30
|
+
)
|
31
|
+
|
32
|
+
|
33
|
+
class LimitExceededError(Exception):
    """Exception raised when a limit is exceeded.

    In some scenarios this error may be raised when `value >= limit` to
    prevent another operation which is guaranteed to exceed the limit from being
    wastefully performed.

    Args:
        type: Type of limit exceeded.
        value: Value compared to.
        limit: Limit applied.
        message (str | None): Optional. Human readable message.
    """

    def __init__(
        self,
        type: Literal["message", "time", "working", "token", "operator", "custom"],
        *,
        value: int,
        limit: int,
        message: str | None = None,
    ) -> None:
        self.type = type
        self.value = value
        self.limit = limit
        # Bug fix: a caller-supplied message was previously discarded
        # (self.message was unconditionally set to the generic default, and the
        # raw argument - possibly None - was passed to Exception, making
        # str(error) render as "None"). Honor the supplied message, falling
        # back to the default, and hand the final text to Exception.
        self.message = message or f"Exceeded {type} limit: {limit:,}"
        super().__init__(self.message)

    def with_state(self, state: TaskState) -> LimitExceededError:
        """Deprecated. Returns self unchanged (state is no longer required)."""
        warn_once(
            logger,
            "LimitExceededError.with_state() is deprecated (no longer required).",
        )
        return self
|
67
|
+
|
68
|
+
|
69
|
+
class Limit(abc.ABC):
    """Base class for all limits.

    A limit is a (re-entrant-friendly) context manager: opening it activates
    the limit and closing it deactivates it.
    """

    @abc.abstractmethod
    def __enter__(self) -> Limit: ...

    @abc.abstractmethod
    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None: ...
|
84
|
+
|
85
|
+
|
86
|
+
@contextmanager
def apply_limits(limits: list[Limit]) -> Iterator[None]:
    """
    Apply a list of limits within a context manager.

    Args:
        limits: List of limits to apply while the context manager is open. Should a
          limit be exceeded, a LimitExceededError is raised.
    """
    with ExitStack() as active:
        # Enter every limit up front; ExitStack unwinds them in reverse
        # order when the block exits (normally or with an error).
        for single in limits:
            active.enter_context(single)
        yield
|
99
|
+
|
100
|
+
|
101
|
+
def token_limit(limit: int | None) -> _TokenLimit:
    """Limits the total number of tokens which can be used.

    Counting begins when the context manager opens and stops when it closes.
    The same context manager may be opened repeatedly, even across different
    execution contexts, and token limits may be stacked.

    Enforcement is "cooperative": code which consumes tokens must itself call
    check_token_limit() after recording usage. A LimitExceededError is raised
    when a limit is exceeded.

    Args:
        limit: The maximum number of tokens that can be used while the context
            manager is open. Tokens used before it was opened are not counted.
            A value of None means unlimited tokens.
    """
    return _TokenLimit(limit)
|
121
|
+
|
122
|
+
|
123
|
+
def record_model_usage(usage: ModelUsage) -> None:
    """Record model usage against any active token limits.

    Does not check if the limit has been exceeded.
    """
    leaf = token_limit_leaf_node.get()
    if leaf is not None:
        leaf.record(usage)
|
132
|
+
|
133
|
+
|
134
|
+
def check_token_limit() -> None:
    """Check if the current token usage exceeds _any_ of the token limits.

    Only the current execution context (e.g. async task) and its parent
    contexts are inspected, but every active limit along that chain is
    checked - not just the most recent one.
    """
    leaf = token_limit_leaf_node.get()
    if leaf is not None:
        leaf.check()
|
145
|
+
|
146
|
+
|
147
|
+
def message_limit(limit: int | None) -> _MessageLimit:
    """Limits the number of messages in a conversation.

    The limit is compared against the conversation's *total* message count,
    not just messages added after the limit was opened. The context manager
    may be opened repeatedly, even across different execution contexts, and
    message limits may be stacked.

    Enforcement is "cooperative": code which updates the message count must
    itself call check_message_limit(). A LimitExceededError is raised when a
    limit is exceeded.

    Args:
        limit: The maximum conversation length (number of messages) allowed
            while the context manager is open. A value of None means
            unlimited messages.
    """
    return _MessageLimit(limit)
|
166
|
+
|
167
|
+
|
168
|
+
def check_message_limit(count: int, raise_for_equal: bool) -> None:
    """Check if the current message count exceeds the active message limit.

    Only the most recent message limit is checked. Ancestors are not checked.

    Args:
        count: The number of messages in the conversation.
        raise_for_equal: If True, raise an error if the message count is equal to the
            limit, otherwise, only raise an error if the message count is greater than
            the limit.
    """
    leaf = message_limit_leaf_node.get()
    if leaf is not None:
        leaf.check(count, raise_for_equal)
|
183
|
+
|
184
|
+
|
185
|
+
class _LimitValueWrapper:
|
186
|
+
"""Container/wrapper type for the limit value.
|
187
|
+
|
188
|
+
This facilitates updating the limit value, which may have been passed to many
|
189
|
+
_TokenLimitNode instances.
|
190
|
+
"""
|
191
|
+
|
192
|
+
def __init__(self, value: int | None) -> None:
|
193
|
+
self.value = value
|
194
|
+
|
195
|
+
|
196
|
+
class _TokenLimit(Limit):
    """Context-manager limit on total token usage (created via token_limit())."""

    def __init__(self, limit: int | None) -> None:
        self._validate_token_limit(limit)
        self._limit_value_wrapper = _LimitValueWrapper(limit)

    def __enter__(self) -> Limit:
        # Push a fresh node onto this execution context's chain. The node is
        # deliberately not stored on self: this context manager may be opened
        # repeatedly, even concurrently in different execution contexts.
        child = _TokenLimitNode(self._limit_value_wrapper, token_limit_leaf_node.get())
        token_limit_leaf_node.set(child)
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        leaf = token_limit_leaf_node.get()
        assert leaf is not None, (
            "Token limit node should not be None when exiting context manager."
        )
        # Pop this node: its parent becomes the leaf again.
        token_limit_leaf_node.set(leaf.parent)

    @property
    def limit(self) -> int | None:
        """Get the configured token limit value."""
        return self._limit_value_wrapper.value

    @limit.setter
    def limit(self, value: int | None) -> None:
        """Update the token limit value.

        Affects every active token limit node derived from this context
        manager. Does not itself re-check usage against the new (possibly
        lower) limit.
        """
        self._validate_token_limit(value)
        self._limit_value_wrapper.value = value

    def _validate_token_limit(self, value: int | None) -> None:
        if value is not None and value < 0:
            raise ValueError("Token limit value must be a non-negative integer.")
|
243
|
+
|
244
|
+
|
245
|
+
class _TokenLimitNode:
    """One link in a chain of active token limits.

    Each node points at its parent (None at the root); because multiple
    execution contexts may branch from a shared ancestor, the nodes overall
    form a tree. A node accumulates token usage for itself and checks that
    usage against a (mutable, shared) limit value.
    """

    def __init__(
        self,
        limit: _LimitValueWrapper,
        parent: _TokenLimitNode | None,
    ) -> None:
        """
        Initialize a token limit node.

        Args:
            limit: Shared wrapper holding the maximum number of tokens that
                can be used while the context manager is open.
            parent: The parent node in the tree (None for the root).
        """
        from inspect_ai.model._model_output import ModelUsage

        self._limit = limit
        self.parent = parent
        self._usage = ModelUsage()

    def record(self, usage: ModelUsage) -> None:
        """Record model usage for this node and every ancestor node."""
        node: _TokenLimitNode | None = self
        while node is not None:
            node._usage += usage
            node = node.parent

    def check(self) -> None:
        """Check this limit and every ancestor limit, innermost first."""
        node: _TokenLimitNode | None = self
        while node is not None:
            node._check_self()
            node = node.parent

    def _check_self(self) -> None:
        from inspect_ai.log._transcript import SampleLimitEvent, transcript

        limit = self._limit.value
        if limit is None:
            return
        total = self._usage.total_tokens
        if total > limit:
            message = (
                f"Token limit exceeded. value: {total:,}; limit: {limit:,}"
            )
            # Log the limit event to the transcript before raising.
            transcript()._event(
                SampleLimitEvent(type="token", limit=limit, message=message)
            )
            raise LimitExceededError(
                "token", value=total, limit=limit, message=message
            )
|
299
|
+
|
300
|
+
|
301
|
+
class _MessageLimit(Limit):
    """Context-manager limit on conversation length (created via message_limit())."""

    def __init__(self, limit: int | None) -> None:
        self._validate_message_limit(limit)
        self._limit_value_wrapper = _LimitValueWrapper(limit)

    def __enter__(self) -> Limit:
        # Push a fresh node onto this execution context's chain. The node is
        # deliberately not stored on self: this context manager may be opened
        # repeatedly, even concurrently in different execution contexts.
        child = _MessageLimitNode(
            self._limit_value_wrapper, message_limit_leaf_node.get()
        )
        message_limit_leaf_node.set(child)
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        leaf = message_limit_leaf_node.get()
        assert leaf is not None, (
            "Message limit node should not be None when exiting context manager."
        )
        # Pop this node: its parent becomes the leaf again.
        message_limit_leaf_node.set(leaf.parent)

    @property
    def limit(self) -> int | None:
        """Get the configured message limit value."""
        return self._limit_value_wrapper.value

    @limit.setter
    def limit(self, value: int | None) -> None:
        """Update the message limit value.

        Affects every active message limit node derived from this context
        manager. Does not itself re-check the count against the new (possibly
        lower) limit.
        """
        self._validate_message_limit(value)
        self._limit_value_wrapper.value = value

    def _validate_message_limit(self, value: int | None) -> None:
        if value is not None and value < 0:
            raise ValueError("Message limit value must be a non-negative integer.")
|
348
|
+
|
349
|
+
|
350
|
+
class _MessageLimitNode:
    """One link in a chain of active message limits.

    Each node points at its parent (None at the root); because multiple
    execution contexts may branch from a shared ancestor, the nodes overall
    form a tree. A node compares a conversation's message count against a
    (mutable, shared) limit value.
    """

    def __init__(
        self,
        limit: _LimitValueWrapper,
        parent: _MessageLimitNode | None,
    ) -> None:
        """
        Initialize a message limit node.

        Args:
            limit: Shared wrapper holding the maximum conversation length
                allowed while this node is the leaf of the current execution
                context.
            parent: The parent node in the tree (None for the root).
        """
        self._limit = limit
        self.parent = parent

    def check(self, count: int, raise_for_equal: bool) -> None:
        """Check if this message limit has been exceeded.

        Does not check parents.
        """
        from inspect_ai.log._transcript import SampleLimitEvent, transcript

        limit = self._limit.value
        if limit is None:
            return
        at_limit = raise_for_equal and count == limit
        if count > limit or at_limit:
            reached_or_exceeded = "reached" if count == limit else "exceeded"
            message = (
                f"Message limit {reached_or_exceeded}. count: {count:,}; "
                f"limit: {limit:,}"
            )
            # Log the limit event to the transcript before raising.
            transcript()._event(
                SampleLimitEvent(type="message", limit=limit, message=message)
            )
            raise LimitExceededError(
                "message", value=count, limit=limit, message=message
            )
|
@@ -0,0 +1,57 @@
|
|
1
|
+
from itertools import tee
|
2
|
+
from typing import Iterable, SupportsIndex, overload
|
3
|
+
|
4
|
+
from inspect_ai.model._chat_message import ChatMessage, ChatMessageBase
|
5
|
+
from inspect_ai.util._limit import check_message_limit
|
6
|
+
|
7
|
+
|
8
|
+
class ChatMessageList(list[ChatMessage]):
    """A limited list of ChatMessage items.

    Raises an exception if an operation would exceed the active message limit.
    """

    def __init__(self, iterable: Iterable[ChatMessage]):
        initial, count = self._iterable_length(iterable)
        self._check_size(count)
        super().__init__(initial)

    def _check_size(self, additional_items: int) -> None:
        # Project the post-mutation length and let the active limit veto it.
        check_message_limit(len(self) + additional_items, raise_for_equal=False)

    def append(self, item: ChatMessage) -> None:
        self._check_size(1)
        super().append(item)

    def extend(self, items: Iterable[ChatMessage]) -> None:
        items, count = self._iterable_length(items)
        self._check_size(count)
        super().extend(items)

    def insert(self, index: SupportsIndex, item: ChatMessage) -> None:
        self._check_size(1)
        super().insert(index, item)

    @overload
    def __setitem__(self, index: SupportsIndex, item: ChatMessage) -> None: ...

    @overload
    def __setitem__(self, index: slice, item: Iterable[ChatMessage]) -> None: ...

    def __setitem__(
        self, index: SupportsIndex | slice, item: ChatMessage | Iterable[ChatMessage]
    ) -> None:
        # Only a slice assignment of an iterable can grow the list; a plain
        # index assignment replaces one element and leaves the length alone.
        if isinstance(index, slice) and not isinstance(item, ChatMessageBase):
            item, count = self._iterable_length(item)
            growth = count - len(self[index])
            if growth > 0:
                self._check_size(growth)

        super().__setitem__(index, item)  # type: ignore[assignment,index]

    def _iterable_length(
        self, items: Iterable[ChatMessage]
    ) -> tuple[Iterable[ChatMessage], int]:
        # tee() so a one-shot iterator can still be consumed after counting.
        replay, probe = tee(items)
        return replay, sum(1 for _ in probe)
|
inspect_ai/util/_span.py
ADDED
@@ -0,0 +1,58 @@
|
|
1
|
+
import contextlib
|
2
|
+
from contextvars import ContextVar
|
3
|
+
from typing import AsyncIterator
|
4
|
+
from uuid import uuid4
|
5
|
+
|
6
|
+
|
7
|
+
@contextlib.asynccontextmanager
async def span(name: str, *, type: str | None = None) -> AsyncIterator[None]:
    """Context manager for establishing a transcript span.

    Emits a SpanBeginEvent on entry and a SpanEndEvent on exit (the end event
    is emitted even if the body raises), so events recorded in between are
    bracketed by the span in the transcript.

    Args:
        name (str): Span name.
        type (str | None): Optional span type.
    """
    # NOTE(review): deferred import - presumably avoids an import cycle with
    # the log package; confirm.
    from inspect_ai.log._transcript import (
        SpanBeginEvent,
        SpanEndEvent,
        track_store_changes,
        transcript,
    )

    # span id
    id = uuid4().hex

    # capture parent id (None when this is a top-level span)
    parent_id = _current_span_id.get()

    # set new current span (token lets us restore the parent at the end)
    token = _current_span_id.set(id)

    # run the span
    try:
        # span begin event
        transcript()._event(
            SpanBeginEvent(
                id=id,
                parent_id=parent_id,
                type=type,
                name=name,
            )
        )

        # run span w/ store change events
        with track_store_changes():
            yield

    finally:
        # send end event (in finally so spans always close, even on error)
        transcript()._event(SpanEndEvent(id=id))

        _current_span_id.reset(token)
|
52
|
+
|
53
|
+
|
54
|
+
def current_span_id() -> str | None:
|
55
|
+
return _current_span_id.get()
|
56
|
+
|
57
|
+
|
58
|
+
_current_span_id: ContextVar[str | None] = ContextVar("_current_span_id", default=None)
|
inspect_ai/util/_subtask.py
CHANGED
@@ -16,6 +16,7 @@ from inspect_ai._util._async import is_callable_coroutine, tg_collect
|
|
16
16
|
from inspect_ai._util.content import Content
|
17
17
|
from inspect_ai._util.trace import trace_action
|
18
18
|
from inspect_ai._util.working import sample_waiting_time
|
19
|
+
from inspect_ai.util._span import span
|
19
20
|
from inspect_ai.util._store import Store, dict_jsonable, init_subtask_store
|
20
21
|
|
21
22
|
SubtaskResult = str | int | float | bool | list[Content]
|
@@ -85,9 +86,7 @@ def subtask(
|
|
85
86
|
|
86
87
|
def create_subtask_wrapper(func: Subtask, name: str | None = None) -> Subtask:
|
87
88
|
from inspect_ai.log._transcript import (
|
88
|
-
Event,
|
89
89
|
SubtaskEvent,
|
90
|
-
track_store_changes,
|
91
90
|
transcript,
|
92
91
|
)
|
93
92
|
|
@@ -118,43 +117,41 @@ def subtask(
|
|
118
117
|
log_input = dict_jsonable(log_input | kwargs)
|
119
118
|
|
120
119
|
# create coroutine so we can provision a subtask contextvars
|
121
|
-
async def run() ->
|
120
|
+
async def run() -> RT:
|
122
121
|
# initialise subtask (provisions store and transcript)
|
123
|
-
|
122
|
+
init_subtask_store(store if store else Store())
|
124
123
|
|
125
124
|
# run the subtask
|
126
125
|
with trace_action(logger, "Subtask", subtask_name):
|
127
|
-
with
|
126
|
+
async with span(name=subtask_name, type="subtask"):
|
127
|
+
# create subtask event
|
128
|
+
waiting_time_start = sample_waiting_time()
|
129
|
+
event = SubtaskEvent(
|
130
|
+
name=subtask_name, input=log_input, type=type, pending=True
|
131
|
+
)
|
132
|
+
transcript()._event(event)
|
133
|
+
|
134
|
+
# run the subtask
|
128
135
|
result = await func(*args, **kwargs)
|
129
136
|
|
130
|
-
|
131
|
-
|
137
|
+
# time accounting
|
138
|
+
completed = datetime.now()
|
139
|
+
waiting_time_end = sample_waiting_time()
|
140
|
+
event.completed = completed
|
141
|
+
event.working_time = (
|
142
|
+
completed - event.timestamp
|
143
|
+
).total_seconds() - (waiting_time_end - waiting_time_start)
|
132
144
|
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
)
|
138
|
-
transcript()._event(event)
|
139
|
-
|
140
|
-
# create and run the task as a coroutine
|
141
|
-
result, events = (await tg_collect([run]))[0]
|
142
|
-
|
143
|
-
# time accounting
|
144
|
-
completed = datetime.now()
|
145
|
-
waiting_time_end = sample_waiting_time()
|
146
|
-
event.completed = completed
|
147
|
-
event.working_time = (completed - event.timestamp).total_seconds() - (
|
148
|
-
waiting_time_end - waiting_time_start
|
149
|
-
)
|
145
|
+
# update event
|
146
|
+
event.result = result
|
147
|
+
event.pending = None
|
148
|
+
transcript()._event_updated(event)
|
150
149
|
|
151
|
-
|
152
|
-
|
153
|
-
event.events = events
|
154
|
-
event.pending = None
|
155
|
-
transcript()._event_updated(event)
|
150
|
+
# return result
|
151
|
+
return result # type: ignore[no-any-return]
|
156
152
|
|
157
|
-
#
|
153
|
+
# create and run the task as a coroutine
|
154
|
+
result = (await tg_collect([run]))[0]
|
158
155
|
return result
|
159
156
|
|
160
157
|
return run_subtask
|
@@ -167,15 +164,3 @@ def subtask(
|
|
167
164
|
return wrapper
|
168
165
|
else:
|
169
166
|
return create_subtask_wrapper(name)
|
170
|
-
|
171
|
-
|
172
|
-
def init_subtask(name: str, store: Store) -> Any:
|
173
|
-
from inspect_ai.log._transcript import (
|
174
|
-
Transcript,
|
175
|
-
init_transcript,
|
176
|
-
)
|
177
|
-
|
178
|
-
init_subtask_store(store)
|
179
|
-
transcript = Transcript(name=name)
|
180
|
-
init_transcript(transcript)
|
181
|
-
return transcript
|