inspect-ai 0.3.93__py3-none-any.whl → 0.3.95__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_display/textual/widgets/samples.py +3 -3
- inspect_ai/_display/textual/widgets/transcript.py +3 -29
- inspect_ai/_eval/loader.py +1 -1
- inspect_ai/_eval/task/run.py +21 -12
- inspect_ai/_util/answer.py +26 -0
- inspect_ai/_util/constants.py +0 -1
- inspect_ai/_util/exception.py +4 -0
- inspect_ai/_util/hash.py +39 -0
- inspect_ai/_util/local_server.py +51 -21
- inspect_ai/_util/path.py +22 -0
- inspect_ai/_util/trace.py +1 -1
- inspect_ai/_util/working.py +4 -0
- inspect_ai/_view/www/dist/assets/index.css +23 -22
- inspect_ai/_view/www/dist/assets/index.js +517 -204
- inspect_ai/_view/www/log-schema.json +375 -0
- inspect_ai/_view/www/package.json +1 -1
- inspect_ai/_view/www/src/@types/log.d.ts +90 -12
- inspect_ai/_view/www/src/app/log-view/navbar/SecondaryBar.tsx +2 -2
- inspect_ai/_view/www/src/app/log-view/tabs/SamplesTab.tsx +1 -4
- inspect_ai/_view/www/src/app/samples/SamplesTools.tsx +3 -13
- inspect_ai/_view/www/src/app/samples/sample-tools/SelectScorer.tsx +45 -48
- inspect_ai/_view/www/src/app/samples/sample-tools/filters.ts +16 -15
- inspect_ai/_view/www/src/app/samples/sample-tools/sample-filter/SampleFilter.tsx +47 -75
- inspect_ai/_view/www/src/app/samples/sample-tools/sample-filter/completions.ts +9 -9
- inspect_ai/_view/www/src/app/samples/transcript/SandboxEventView.module.css +2 -1
- inspect_ai/_view/www/src/app/samples/transcript/SpanEventView.tsx +174 -0
- inspect_ai/_view/www/src/app/samples/transcript/ToolEventView.tsx +8 -8
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptView.tsx +12 -2
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualListComponent.module.css +1 -1
- inspect_ai/_view/www/src/app/samples/transcript/event/EventPanel.tsx +0 -3
- inspect_ai/_view/www/src/app/samples/transcript/transform/fixups.ts +87 -25
- inspect_ai/_view/www/src/app/samples/transcript/transform/treeify.ts +229 -17
- inspect_ai/_view/www/src/app/samples/transcript/transform/utils.ts +11 -0
- inspect_ai/_view/www/src/app/samples/transcript/types.ts +5 -1
- inspect_ai/_view/www/src/app/types.ts +12 -2
- inspect_ai/_view/www/src/components/ExpandablePanel.module.css +1 -1
- inspect_ai/_view/www/src/components/ExpandablePanel.tsx +5 -5
- inspect_ai/_view/www/src/state/hooks.ts +19 -3
- inspect_ai/_view/www/src/state/logSlice.ts +23 -5
- inspect_ai/_view/www/yarn.lock +9 -9
- inspect_ai/agent/_as_solver.py +3 -1
- inspect_ai/agent/_as_tool.py +6 -4
- inspect_ai/agent/_bridge/patch.py +1 -3
- inspect_ai/agent/_handoff.py +5 -1
- inspect_ai/agent/_react.py +4 -3
- inspect_ai/agent/_run.py +6 -1
- inspect_ai/agent/_types.py +9 -0
- inspect_ai/analysis/__init__.py +0 -0
- inspect_ai/analysis/beta/__init__.py +57 -0
- inspect_ai/analysis/beta/_dataframe/__init__.py +0 -0
- inspect_ai/analysis/beta/_dataframe/columns.py +145 -0
- inspect_ai/analysis/beta/_dataframe/evals/__init__.py +0 -0
- inspect_ai/analysis/beta/_dataframe/evals/columns.py +132 -0
- inspect_ai/analysis/beta/_dataframe/evals/extract.py +23 -0
- inspect_ai/analysis/beta/_dataframe/evals/table.py +140 -0
- inspect_ai/analysis/beta/_dataframe/events/__init__.py +0 -0
- inspect_ai/analysis/beta/_dataframe/events/columns.py +37 -0
- inspect_ai/analysis/beta/_dataframe/events/table.py +14 -0
- inspect_ai/analysis/beta/_dataframe/extract.py +54 -0
- inspect_ai/analysis/beta/_dataframe/messages/__init__.py +0 -0
- inspect_ai/analysis/beta/_dataframe/messages/columns.py +60 -0
- inspect_ai/analysis/beta/_dataframe/messages/extract.py +21 -0
- inspect_ai/analysis/beta/_dataframe/messages/table.py +87 -0
- inspect_ai/analysis/beta/_dataframe/record.py +377 -0
- inspect_ai/analysis/beta/_dataframe/samples/__init__.py +0 -0
- inspect_ai/analysis/beta/_dataframe/samples/columns.py +73 -0
- inspect_ai/analysis/beta/_dataframe/samples/extract.py +82 -0
- inspect_ai/analysis/beta/_dataframe/samples/table.py +329 -0
- inspect_ai/analysis/beta/_dataframe/util.py +157 -0
- inspect_ai/analysis/beta/_dataframe/validate.py +171 -0
- inspect_ai/dataset/_dataset.py +6 -3
- inspect_ai/log/__init__.py +10 -0
- inspect_ai/log/_convert.py +4 -9
- inspect_ai/log/_file.py +1 -1
- inspect_ai/log/_log.py +21 -1
- inspect_ai/log/_samples.py +14 -17
- inspect_ai/log/_transcript.py +77 -35
- inspect_ai/log/_tree.py +118 -0
- inspect_ai/model/_call_tools.py +44 -35
- inspect_ai/model/_model.py +51 -44
- inspect_ai/model/_openai_responses.py +17 -18
- inspect_ai/model/_providers/anthropic.py +30 -5
- inspect_ai/model/_providers/hf.py +27 -1
- inspect_ai/model/_providers/providers.py +1 -1
- inspect_ai/model/_providers/sglang.py +8 -2
- inspect_ai/model/_providers/vllm.py +6 -2
- inspect_ai/scorer/_choice.py +1 -2
- inspect_ai/solver/_chain.py +1 -1
- inspect_ai/solver/_fork.py +1 -1
- inspect_ai/solver/_multiple_choice.py +9 -23
- inspect_ai/solver/_plan.py +2 -2
- inspect_ai/solver/_task_state.py +7 -3
- inspect_ai/solver/_transcript.py +6 -7
- inspect_ai/tool/_mcp/_context.py +3 -5
- inspect_ai/tool/_mcp/_mcp.py +6 -5
- inspect_ai/tool/_mcp/server.py +1 -1
- inspect_ai/tool/_tools/_execute.py +4 -1
- inspect_ai/tool/_tools/_think.py +1 -1
- inspect_ai/tool/_tools/_web_search/__init__.py +3 -0
- inspect_ai/tool/_tools/{_web_search.py → _web_search/_google.py} +56 -103
- inspect_ai/tool/_tools/_web_search/_tavily.py +77 -0
- inspect_ai/tool/_tools/_web_search/_web_search.py +85 -0
- inspect_ai/util/__init__.py +4 -0
- inspect_ai/util/_anyio.py +11 -0
- inspect_ai/util/_collect.py +50 -0
- inspect_ai/util/_sandbox/events.py +3 -2
- inspect_ai/util/_span.py +58 -0
- inspect_ai/util/_subtask.py +27 -42
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/METADATA +8 -1
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/RECORD +114 -82
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/WHEEL +1 -1
- inspect_ai/_display/core/group.py +0 -79
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/licenses/LICENSE +0 -0
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/top_level.txt +0 -0
inspect_ai/tool/_tools/_web_search/_tavily.py
ADDED
@@ -0,0 +1,77 @@
+import os
+from typing import Awaitable, Callable
+
+import httpx
+from pydantic import BaseModel, Field
+from tenacity import (
+    retry,
+    retry_if_exception,
+    stop_after_attempt,
+    stop_after_delay,
+    wait_exponential_jitter,
+)
+
+from inspect_ai._util.error import PrerequisiteError
+from inspect_ai._util.httpx import httpx_should_retry, log_httpx_retry_attempt
+from inspect_ai.util._concurrency import concurrency
+
+
+class TavilySearchResult(BaseModel):
+    title: str
+    url: str
+    content: str
+    score: float
+
+
+class TavilySearchResponse(BaseModel):
+    query: str
+    answer: str | None = Field(default=None)
+    images: list[object]
+    results: list[TavilySearchResult]
+    response_time: float
+
+
+def tavily_search_provider(
+    num_results: int, max_connections: int
+) -> Callable[[str], Awaitable[str | None]]:
+    tavily_api_key = os.environ.get("TAVILY_API_KEY", None)
+    if not tavily_api_key:
+        raise PrerequisiteError(
+            "TAVILY_API_KEY not set in the environment. Please ensure this variable is defined to use Tavily with the web_search tool.\n\nLearn more about the Tavily web search provider at https://inspect.aisi.org.uk/tools.html#tavily-provider"
+        )
+    if num_results > 20:
+        raise PrerequisiteError(
+            "The Tavily search provider is limited to 20 results per query."
+        )
+
+    # Create the client within the provider
+    client = httpx.AsyncClient(timeout=30)
+
+    async def search(query: str) -> str | None:
+        search_url = "https://api.tavily.com/search"
+        headers = {
+            "Authorization": f"Bearer {tavily_api_key}",
+        }
+        body = {
+            "query": query,
+            "max_results": 10,  # num_results,
+            # "search_depth": "advanced",
+            "include_answer": "advanced",
+        }
+
+        # retry up to 5 times over a period of up to 1 minute
+        @retry(
+            wait=wait_exponential_jitter(),
+            stop=stop_after_attempt(5) | stop_after_delay(60),
+            retry=retry_if_exception(httpx_should_retry),
+            before_sleep=log_httpx_retry_attempt(search_url),
+        )
+        async def _search() -> httpx.Response:
+            response = await client.post(search_url, headers=headers, json=body)
+            response.raise_for_status()
+            return response
+
+        async with concurrency("tavily_web_search", max_connections):
+            return TavilySearchResponse.model_validate((await _search()).json()).answer
+
+    return search
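The provider above is an internal module normally reached through the `web_search()` tool, but a minimal sketch of exercising it directly could look like the following (the import path is taken from the file list above; the query string is purely illustrative and `TAVILY_API_KEY` must be set in the environment):

import asyncio

from inspect_ai.tool._tools._web_search._tavily import tavily_search_provider

async def demo() -> None:
    # build a search callable that allows at most 10 concurrent requests
    search = tavily_search_provider(num_results=5, max_connections=10)
    # returns Tavily's synthesized answer for the query (or None)
    answer = await search("current weather in London")
    print(answer)

asyncio.run(demo())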
inspect_ai/tool/_tools/_web_search/_web_search.py
ADDED
@@ -0,0 +1,85 @@
+from typing import Literal
+
+from inspect_ai._util.deprecation import deprecation_warning
+
+from ..._tool import Tool, ToolResult, tool
+from ._google import google_search_provider, maybe_get_google_api_keys
+from ._tavily import tavily_search_provider
+
+
+@tool
+def web_search(
+    provider: Literal["tavily", "google"] | None = None,
+    num_results: int = 3,
+    max_provider_calls: int = 3,
+    max_connections: int = 10,
+    model: str | None = None,
+) -> Tool:
+    """Web search tool.
+
+    A tool that can be registered for use by models to search the web. Use
+    the `use_tools()` solver to make the tool available (e.g.
+    `use_tools(web_search(provider="tavily"))`)
+
+    A web search is conducted using the specified provider.
+    - When using Tavily, all logic for relevance and summarization is handled by
+      the Tavily API.
+    - When using Google, the results are parsed for relevance using the specified
+      model, and the top 'num_results' relevant pages are returned.
+
+    See further documentation at <https://inspect.aisi.org.uk/tools-standard.html#sec-web-search>.
+
+    Args:
+      provider: Search provider to use:
+        - "tavily": Uses Tavily's Research API.
+        - "google": Uses Google Custom Search.
+        Note: The `| None` type is only for backwards compatibility. Passing
+        `None` is deprecated.
+      num_results: The number of search result pages used to provide information
+        back to the model.
+      max_provider_calls: Maximum number of search calls to make to the search
+        provider.
+      max_connections: Maximum number of concurrent connections to API endpoint
+        of search provider.
+      model: Model used to parse web pages for relevance - used only by the
+        `google` provider.
+
+    Returns:
+      A tool that can be registered for use by models to search the web.
+    """
+    if provider is None:
+        if maybe_get_google_api_keys():
+            deprecation_warning(
+                "The `google` `web_search` provider was inferred based on the presence of environment variables. Please specify the provider explicitly to avoid this warning."
+            )
+            provider = "google"
+        else:
+            raise ValueError(
+                "Omitting `provider` is no longer supported. Please specify the `web_search` provider explicitly to avoid this error."
+            )
+
+    search_provider = (
+        google_search_provider(num_results, max_provider_calls, max_connections, model)
+        if provider == "google"
+        else tavily_search_provider(num_results, max_connections)
+    )
+
+    async def execute(query: str) -> ToolResult:
+        """
+        Use the web_search tool to perform keyword searches of the web.
+
+        Args:
+          query (str): Search query.
+        """
+        search_result = await search_provider(query)
+
+        return (
+            (
+                "Here are your web search results. Please read them carefully as they may be useful later!\n"
+                + search_result
+            )
+            if search_result
+            else ("I'm sorry, I couldn't find any relevant information on the web.")
+        )
+
+    return execute
inspect_ai/util/__init__.py
CHANGED
@@ -8,6 +8,7 @@ from inspect_ai.util._limit import (
     token_limit,
 )
 
+from ._collect import collect
 from ._concurrency import concurrency
 from ._console import input_screen
 from ._display import DisplayType, display_counter, display_type
@@ -28,6 +29,7 @@ from ._sandbox import (
     sandbox_with,
     sandboxenv,
 )
+from ._span import span
 from ._store import Store, store
 from ._store_model import StoreModel, store_as
 from ._subprocess import (
@@ -71,6 +73,8 @@ __all__ = [
     "store",
     "StoreModel",
     "store_as",
+    "span",
+    "collect",
     "Subtask",
     "subtask",
     "throttle",
inspect_ai/util/_anyio.py
CHANGED
@@ -1,6 +1,10 @@
 import itertools
 import sys
 
+import anyio
+
+from inspect_ai._util._async import current_async_backend
+
 if sys.version_info < (3, 11):
     from exceptiongroup import ExceptionGroup
 
@@ -36,3 +40,10 @@ def _flatten_exception(exc: Exception) -> list[Exception]:
     ]
 
     return maybe_this_exception + other_exceptions
+
+
+def safe_current_task_id() -> int | None:
+    if current_async_backend() is not None:
+        return anyio.get_current_task().id
+    else:
+        return None
inspect_ai/util/_collect.py
ADDED
@@ -0,0 +1,50 @@
+import sys
+from typing import Awaitable, TypeVar, cast
+
+import anyio
+
+from ._span import span
+
+if sys.version_info < (3, 11):
+    from exceptiongroup import ExceptionGroup
+
+
+T = TypeVar("T")
+
+
+async def collect(*tasks: Awaitable[T]) -> list[T]:
+    """Run and collect the results of one or more async coroutines.
+
+    Similar to [`asyncio.gather()`](https://docs.python.org/3/library/asyncio-task.html#asyncio.gather),
+    but also works when [Trio](https://trio.readthedocs.io/en/stable/) is the async backend.
+
+    Automatically includes each task in a `span()`, which
+    ensures that its events are grouped together in the transcript.
+
+    Using `collect()` in preference to `asyncio.gather()` is highly recommended
+    for both Trio compatibility and more legible transcript output.
+
+    Args:
+      *tasks: Tasks to run
+
+    Returns:
+      List of task results.
+    """
+    results: list[None | T] = [None] * len(tasks)
+
+    try:
+        async with anyio.create_task_group() as tg:
+
+            async def run_task(index: int, task: Awaitable[T]) -> None:
+                async with span(f"task-{index + 1}", type="task"):
+                    results[index] = await task
+
+            for i, task in enumerate(tasks):
+                tg.start_soon(run_task, i, task)
+    except ExceptionGroup as ex:
+        if len(ex.exceptions) == 1:
+            raise ex.exceptions[0] from None
+        else:
+            raise
+
+    return cast(list[T], results)
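A short sketch of how `collect()` might be used in agent or solver code in place of `asyncio.gather()` (the fan-out helper and prompts below are hypothetical):

from inspect_ai.model import get_model
from inspect_ai.util import collect

async def fan_out(prompt_a: str, prompt_b: str) -> list[str]:
    model = get_model()
    # each awaitable runs inside its own "task-N" span in the transcript
    outputs = await collect(
        model.generate(prompt_a),
        model.generate(prompt_b),
    )
    return [output.completion for output in outputs]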
inspect_ai/util/_sandbox/events.py
CHANGED
@@ -1,7 +1,7 @@
 import contextlib
 import shlex
 from datetime import datetime
-from typing import Iterator, Literal, Type, Union, overload
+from typing import Any, Iterator, Literal, Type, Union, overload
 
 from pydantic import JsonValue
 from pydantic_core import to_jsonable_python
@@ -134,7 +134,8 @@ class SandboxEnvironmentProxy(SandboxEnvironment):
 
     @override
     async def connection(self, *, user: str | None = None) -> SandboxConnection:
-
+        params: dict[str, Any] = {"user": user} if user is not None else {}
+        return await self._sandbox.connection(**params)
 
     @override
     def as_type(self, sandbox_cls: Type[ST]) -> ST:
inspect_ai/util/_span.py
ADDED
@@ -0,0 +1,58 @@
+import contextlib
+from contextvars import ContextVar
+from typing import AsyncIterator
+from uuid import uuid4
+
+
+@contextlib.asynccontextmanager
+async def span(name: str, *, type: str | None = None) -> AsyncIterator[None]:
+    """Context manager for establishing a transcript span.
+
+    Args:
+      name (str): Step name.
+      type (str | None): Optional span type.
+    """
+    from inspect_ai.log._transcript import (
+        SpanBeginEvent,
+        SpanEndEvent,
+        track_store_changes,
+        transcript,
+    )
+
+    # span id
+    id = uuid4().hex
+
+    # capture parent id
+    parent_id = _current_span_id.get()
+
+    # set new current span (reset at the end)
+    token = _current_span_id.set(id)
+
+    # run the span
+    try:
+        # span begin event
+        transcript()._event(
+            SpanBeginEvent(
+                id=id,
+                parent_id=parent_id,
+                type=type,
+                name=name,
+            )
+        )
+
+        # run span w/ store change events
+        with track_store_changes():
+            yield
+
+    finally:
+        # send end event
+        transcript()._event(SpanEndEvent(id=id))
+
+        _current_span_id.reset(token)
+
+
+def current_span_id() -> str | None:
+    return _current_span_id.get()
+
+
+_current_span_id: ContextVar[str | None] = ContextVar("_current_span_id", default=None)
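A minimal sketch of using the new `span()` context manager from solver code to group related transcript events (the solver and span names are hypothetical):

from inspect_ai.solver import Generate, TaskState, solver
from inspect_ai.util import span

@solver
def research_then_answer():
    async def solve(state: TaskState, generate: Generate) -> TaskState:
        # events emitted inside this block are nested under a "research" span
        async with span("research", type="phase"):
            state = await generate(state)
        return state

    return solve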
inspect_ai/util/_subtask.py
CHANGED
@@ -16,6 +16,7 @@ from inspect_ai._util._async import is_callable_coroutine, tg_collect
 from inspect_ai._util.content import Content
 from inspect_ai._util.trace import trace_action
 from inspect_ai._util.working import sample_waiting_time
+from inspect_ai.util._span import span
 from inspect_ai.util._store import Store, dict_jsonable, init_subtask_store
 
 SubtaskResult = str | int | float | bool | list[Content]
@@ -85,9 +86,7 @@ def subtask(
 
     def create_subtask_wrapper(func: Subtask, name: str | None = None) -> Subtask:
         from inspect_ai.log._transcript import (
-            Event,
             SubtaskEvent,
-            track_store_changes,
             transcript,
         )
 
@@ -118,43 +117,41 @@
             log_input = dict_jsonable(log_input | kwargs)
 
             # create coroutine so we can provision a subtask contextvars
-            async def run() ->
+            async def run() -> RT:
                 # initialise subtask (provisions store and transcript)
-
+                init_subtask_store(store if store else Store())
 
                 # run the subtask
                 with trace_action(logger, "Subtask", subtask_name):
-                    with
+                    async with span(name=subtask_name, type="subtask"):
+                        # create subtask event
+                        waiting_time_start = sample_waiting_time()
+                        event = SubtaskEvent(
+                            name=subtask_name, input=log_input, type=type, pending=True
+                        )
+                        transcript()._event(event)
+
+                        # run the subtask
                         result = await func(*args, **kwargs)
 
-
-
+                        # time accounting
+                        completed = datetime.now()
+                        waiting_time_end = sample_waiting_time()
+                        event.completed = completed
+                        event.working_time = (
+                            completed - event.timestamp
+                        ).total_seconds() - (waiting_time_end - waiting_time_start)
 
-
-
-
-
-            )
-            transcript()._event(event)
-
-            # create and run the task as a coroutine
-            result, events = (await tg_collect([run]))[0]
-
-            # time accounting
-            completed = datetime.now()
-            waiting_time_end = sample_waiting_time()
-            event.completed = completed
-            event.working_time = (completed - event.timestamp).total_seconds() - (
-                waiting_time_end - waiting_time_start
-            )
+                        # update event
+                        event.result = result
+                        event.pending = None
+                        transcript()._event_updated(event)
 
-
-
-            event.events = events
-            event.pending = None
-            transcript()._event_updated(event)
+                        # return result
+                        return result  # type: ignore[no-any-return]
 
-            #
+            # create and run the task as a coroutine
+            result = (await tg_collect([run]))[0]
             return result
 
         return run_subtask
@@ -167,15 +164,3 @@ def subtask(
         return wrapper
     else:
        return create_subtask_wrapper(name)
-
-
-def init_subtask(name: str, store: Store) -> Any:
-    from inspect_ai.log._transcript import (
-        Transcript,
-        init_transcript,
-    )
-
-    init_subtask_store(store)
-    transcript = Transcript(name=name)
-    init_transcript(transcript)
-    return transcript
{inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: inspect_ai
-Version: 0.3.93
+Version: 0.3.95
 Summary: Framework for large language model evaluations
 Author: UK AI Security Institute
 License: MIT License
@@ -32,6 +32,8 @@ Requires-Dist: httpx
 Requires-Dist: ijson>=3.2.0
 Requires-Dist: jsonlines>=3.0.0
 Requires-Dist: jsonpatch>=1.32
+Requires-Dist: jsonpath-ng>=1.7.0
+Requires-Dist: jsonref>=1.1.0
 Requires-Dist: jsonschema>3.1.1
 Requires-Dist: mmh3>3.1.0
 Requires-Dist: nest_asyncio
@@ -59,6 +61,7 @@ Requires-Dist: google-genai; extra == "dev"
 Requires-Dist: griffe; extra == "dev"
 Requires-Dist: groq; extra == "dev"
 Requires-Dist: ipython; extra == "dev"
+Requires-Dist: jsonpath-ng; extra == "dev"
 Requires-Dist: markdown; extra == "dev"
 Requires-Dist: mcp; extra == "dev"
 Requires-Dist: mistralai; extra == "dev"
@@ -66,9 +69,11 @@ Requires-Dist: moto[server]; extra == "dev"
 Requires-Dist: mypy; extra == "dev"
 Requires-Dist: nbformat; extra == "dev"
 Requires-Dist: openai; extra == "dev"
+Requires-Dist: pandas>=2.0.0; extra == "dev"
 Requires-Dist: panflute; extra == "dev"
 Requires-Dist: pip; extra == "dev"
 Requires-Dist: pre-commit; extra == "dev"
+Requires-Dist: pyarrow>=10.0.1; extra == "dev"
 Requires-Dist: pylint; extra == "dev"
 Requires-Dist: pytest; extra == "dev"
 Requires-Dist: pytest-asyncio; extra == "dev"
@@ -78,6 +83,8 @@ Requires-Dist: pytest-xdist; extra == "dev"
 Requires-Dist: ruff==0.9.6; extra == "dev"
 Requires-Dist: textual-dev>=0.86.2; extra == "dev"
 Requires-Dist: trio; extra == "dev"
+Requires-Dist: pandas-stubs; extra == "dev"
+Requires-Dist: pyarrow-stubs; extra == "dev"
 Requires-Dist: types-Markdown; extra == "dev"
 Requires-Dist: types-PyYAML; extra == "dev"
 Requires-Dist: types-beautifulsoup4; extra == "dev"