inspect-ai 0.3.93__py3-none-any.whl → 0.3.95__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_display/textual/widgets/samples.py +3 -3
- inspect_ai/_display/textual/widgets/transcript.py +3 -29
- inspect_ai/_eval/loader.py +1 -1
- inspect_ai/_eval/task/run.py +21 -12
- inspect_ai/_util/answer.py +26 -0
- inspect_ai/_util/constants.py +0 -1
- inspect_ai/_util/exception.py +4 -0
- inspect_ai/_util/hash.py +39 -0
- inspect_ai/_util/local_server.py +51 -21
- inspect_ai/_util/path.py +22 -0
- inspect_ai/_util/trace.py +1 -1
- inspect_ai/_util/working.py +4 -0
- inspect_ai/_view/www/dist/assets/index.css +23 -22
- inspect_ai/_view/www/dist/assets/index.js +517 -204
- inspect_ai/_view/www/log-schema.json +375 -0
- inspect_ai/_view/www/package.json +1 -1
- inspect_ai/_view/www/src/@types/log.d.ts +90 -12
- inspect_ai/_view/www/src/app/log-view/navbar/SecondaryBar.tsx +2 -2
- inspect_ai/_view/www/src/app/log-view/tabs/SamplesTab.tsx +1 -4
- inspect_ai/_view/www/src/app/samples/SamplesTools.tsx +3 -13
- inspect_ai/_view/www/src/app/samples/sample-tools/SelectScorer.tsx +45 -48
- inspect_ai/_view/www/src/app/samples/sample-tools/filters.ts +16 -15
- inspect_ai/_view/www/src/app/samples/sample-tools/sample-filter/SampleFilter.tsx +47 -75
- inspect_ai/_view/www/src/app/samples/sample-tools/sample-filter/completions.ts +9 -9
- inspect_ai/_view/www/src/app/samples/transcript/SandboxEventView.module.css +2 -1
- inspect_ai/_view/www/src/app/samples/transcript/SpanEventView.tsx +174 -0
- inspect_ai/_view/www/src/app/samples/transcript/ToolEventView.tsx +8 -8
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptView.tsx +12 -2
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualListComponent.module.css +1 -1
- inspect_ai/_view/www/src/app/samples/transcript/event/EventPanel.tsx +0 -3
- inspect_ai/_view/www/src/app/samples/transcript/transform/fixups.ts +87 -25
- inspect_ai/_view/www/src/app/samples/transcript/transform/treeify.ts +229 -17
- inspect_ai/_view/www/src/app/samples/transcript/transform/utils.ts +11 -0
- inspect_ai/_view/www/src/app/samples/transcript/types.ts +5 -1
- inspect_ai/_view/www/src/app/types.ts +12 -2
- inspect_ai/_view/www/src/components/ExpandablePanel.module.css +1 -1
- inspect_ai/_view/www/src/components/ExpandablePanel.tsx +5 -5
- inspect_ai/_view/www/src/state/hooks.ts +19 -3
- inspect_ai/_view/www/src/state/logSlice.ts +23 -5
- inspect_ai/_view/www/yarn.lock +9 -9
- inspect_ai/agent/_as_solver.py +3 -1
- inspect_ai/agent/_as_tool.py +6 -4
- inspect_ai/agent/_bridge/patch.py +1 -3
- inspect_ai/agent/_handoff.py +5 -1
- inspect_ai/agent/_react.py +4 -3
- inspect_ai/agent/_run.py +6 -1
- inspect_ai/agent/_types.py +9 -0
- inspect_ai/analysis/__init__.py +0 -0
- inspect_ai/analysis/beta/__init__.py +57 -0
- inspect_ai/analysis/beta/_dataframe/__init__.py +0 -0
- inspect_ai/analysis/beta/_dataframe/columns.py +145 -0
- inspect_ai/analysis/beta/_dataframe/evals/__init__.py +0 -0
- inspect_ai/analysis/beta/_dataframe/evals/columns.py +132 -0
- inspect_ai/analysis/beta/_dataframe/evals/extract.py +23 -0
- inspect_ai/analysis/beta/_dataframe/evals/table.py +140 -0
- inspect_ai/analysis/beta/_dataframe/events/__init__.py +0 -0
- inspect_ai/analysis/beta/_dataframe/events/columns.py +37 -0
- inspect_ai/analysis/beta/_dataframe/events/table.py +14 -0
- inspect_ai/analysis/beta/_dataframe/extract.py +54 -0
- inspect_ai/analysis/beta/_dataframe/messages/__init__.py +0 -0
- inspect_ai/analysis/beta/_dataframe/messages/columns.py +60 -0
- inspect_ai/analysis/beta/_dataframe/messages/extract.py +21 -0
- inspect_ai/analysis/beta/_dataframe/messages/table.py +87 -0
- inspect_ai/analysis/beta/_dataframe/record.py +377 -0
- inspect_ai/analysis/beta/_dataframe/samples/__init__.py +0 -0
- inspect_ai/analysis/beta/_dataframe/samples/columns.py +73 -0
- inspect_ai/analysis/beta/_dataframe/samples/extract.py +82 -0
- inspect_ai/analysis/beta/_dataframe/samples/table.py +329 -0
- inspect_ai/analysis/beta/_dataframe/util.py +157 -0
- inspect_ai/analysis/beta/_dataframe/validate.py +171 -0
- inspect_ai/dataset/_dataset.py +6 -3
- inspect_ai/log/__init__.py +10 -0
- inspect_ai/log/_convert.py +4 -9
- inspect_ai/log/_file.py +1 -1
- inspect_ai/log/_log.py +21 -1
- inspect_ai/log/_samples.py +14 -17
- inspect_ai/log/_transcript.py +77 -35
- inspect_ai/log/_tree.py +118 -0
- inspect_ai/model/_call_tools.py +44 -35
- inspect_ai/model/_model.py +51 -44
- inspect_ai/model/_openai_responses.py +17 -18
- inspect_ai/model/_providers/anthropic.py +30 -5
- inspect_ai/model/_providers/hf.py +27 -1
- inspect_ai/model/_providers/providers.py +1 -1
- inspect_ai/model/_providers/sglang.py +8 -2
- inspect_ai/model/_providers/vllm.py +6 -2
- inspect_ai/scorer/_choice.py +1 -2
- inspect_ai/solver/_chain.py +1 -1
- inspect_ai/solver/_fork.py +1 -1
- inspect_ai/solver/_multiple_choice.py +9 -23
- inspect_ai/solver/_plan.py +2 -2
- inspect_ai/solver/_task_state.py +7 -3
- inspect_ai/solver/_transcript.py +6 -7
- inspect_ai/tool/_mcp/_context.py +3 -5
- inspect_ai/tool/_mcp/_mcp.py +6 -5
- inspect_ai/tool/_mcp/server.py +1 -1
- inspect_ai/tool/_tools/_execute.py +4 -1
- inspect_ai/tool/_tools/_think.py +1 -1
- inspect_ai/tool/_tools/_web_search/__init__.py +3 -0
- inspect_ai/tool/_tools/{_web_search.py → _web_search/_google.py} +56 -103
- inspect_ai/tool/_tools/_web_search/_tavily.py +77 -0
- inspect_ai/tool/_tools/_web_search/_web_search.py +85 -0
- inspect_ai/util/__init__.py +4 -0
- inspect_ai/util/_anyio.py +11 -0
- inspect_ai/util/_collect.py +50 -0
- inspect_ai/util/_sandbox/events.py +3 -2
- inspect_ai/util/_span.py +58 -0
- inspect_ai/util/_subtask.py +27 -42
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/METADATA +8 -1
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/RECORD +114 -82
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/WHEEL +1 -1
- inspect_ai/_display/core/group.py +0 -79
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/licenses/LICENSE +0 -0
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/top_level.txt +0 -0
inspect_ai/agent/_as_tool.py
CHANGED
@@ -11,6 +11,7 @@ from inspect_ai.tool._tool_def import ToolDef, validate_tool_parameters
 from inspect_ai.tool._tool_info import ToolInfo, parse_tool_info
 from inspect_ai.tool._tool_params import ToolParam
 from inspect_ai.util._limit import Limit, apply_limits
+from inspect_ai.util._span import span
 
 from ._agent import AGENT_DESCRIPTION, Agent, AgentState
 
@@ -49,13 +50,17 @@ def as_tool(
         "Agent passed to as_tool was not created by an @agent decorated function"
     )
 
+    # get tool_info
+    tool_info = agent_tool_info(agent, description, **agent_kwargs)
+
     async def execute(input: str, *args: Any, **kwargs: Any) -> ToolResult:
         # prepare state
         state = AgentState(messages=[ChatMessageUser(content=input, source="input")])
 
        # run the agent with limits
        with apply_limits(limits):
-            state = await agent(state, *args, **(agent_kwargs | kwargs))
+            async with span(name=tool_info.name, type="agent"):
+                state = await agent(state, *args, **(agent_kwargs | kwargs))
 
        # find assistant message to read content from (prefer output)
        if not state.output.empty:
@@ -67,9 +72,6 @@ def as_tool(
         else:
             return ""
 
-    # get tool_info
-    tool_info = agent_tool_info(agent, description, **agent_kwargs)
-
     # add "input" param
     tool_info.parameters.properties = {
         "input": ToolParam(type="string", description="Input message.")

inspect_ai/agent/_bridge/patch.py
CHANGED
@@ -3,7 +3,7 @@ import re
 from contextvars import ContextVar
 from functools import wraps
 from time import time
-from typing import Any, AsyncGenerator, Optional, Type, cast
+from typing import Any, AsyncGenerator, Type, cast
 
 from openai._base_client import AsyncAPIClient, _AsyncStreamT
 from openai._models import FinalRequestOptions
@@ -65,7 +65,6 @@ def init_openai_request_patch() -> None:
         *,
         stream: bool = False,
         stream_cls: type[_AsyncStreamT] | None = None,
-        remaining_retries: Optional[int] = None,
     ) -> Any:
         # we have patched the underlying request method so now need to figure out when to
         # patch and when to stand down
@@ -88,7 +87,6 @@ def init_openai_request_patch() -> None:
             options,
             stream=stream,
             stream_cls=stream_cls,
-            remaining_retries=remaining_retries,
         )
 
     setattr(AsyncAPIClient, "request", patched_request)
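Note: the `_as_tool.py` change moves tool-info resolution ahead of `execute()` so that the tool's name is available to wrap each agent invocation in an `agent`-type span. A minimal sketch of what callers see, assuming the public `inspect_ai.agent` exports and an illustrative `@agent` function named `greeter`:

    from inspect_ai.agent import AgentState, agent, as_tool

    @agent
    def greeter():
        async def execute(state: AgentState) -> AgentState:
            # ...call a model and append to state.messages here...
            return state

        return execute

    # calls to this tool now execute inside:
    #   async with span(name="greeter", type="agent")
    greeter_tool = as_tool(greeter())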
inspect_ai/agent/_handoff.py
CHANGED
@@ -57,7 +57,9 @@ def handoff(
     tool_info = agent_tool_info(agent, description, **agent_kwargs)
 
     # AgentTool calls will be intercepted by execute_tools
-    agent_tool = AgentTool(agent, input_filter, output_filter, limits, **agent_kwargs)
+    agent_tool = AgentTool(
+        agent, tool_info.name, input_filter, output_filter, limits, **agent_kwargs
+    )
     tool_name = tool_name or f"transfer_to_{tool_info.name}"
     set_registry_info(agent_tool, RegistryInfo(type="tool", name=tool_name))
     set_tool_description(
@@ -75,12 +77,14 @@ class AgentTool(Tool):
     def __init__(
         self,
         agent: Agent,
+        name: str,
         input_filter: MessageFilter | None = None,
         output_filter: MessageFilter | None = None,
         limits: list[Limit] = [],
         **kwargs: Any,
     ):
         self.agent = agent
+        self.name = name
         self.input_filter = input_filter
         self.output_filter = output_filter
         self.limits = limits
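Note: `handoff()` now threads the agent's registry name into `AgentTool`, so the tool object knows its own name (used for span naming when calls are intercepted by `execute_tools`). Call sites are unchanged; a brief sketch, where `researcher` is an illustrative `@agent` function:

    from inspect_ai.agent import handoff

    # the handoff tool name still defaults to f"transfer_to_{tool_info.name}",
    # e.g. "transfer_to_researcher"
    researcher_tool = handoff(researcher())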
inspect_ai/agent/_react.py
CHANGED
@@ -195,9 +195,10 @@ def react(
     answer = submission(messages)
     if answer is not None:
         # set the output to the answer for scoring
-        state.output.completion = (
-            f"{state.output.completion}\n\n{answer}".strip()
-        )
+        if submit.answer_only:
+            state.output.completion = answer
+        else:
+            state.output.completion = f"{state.output.completion}{submit.answer_delimiter}{answer}".strip()
 
     # exit if we are at max_attempts
     attempt_count += 1
inspect_ai/agent/_run.py
CHANGED
@@ -1,8 +1,10 @@
 from copy import copy
 from typing import Any
 
+from inspect_ai._util.registry import registry_unqualified_name
 from inspect_ai.model._chat_message import ChatMessage, ChatMessageUser
 from inspect_ai.util._limit import Limit, apply_limits
+from inspect_ai.util._span import span
 
 from ._agent import Agent, AgentState
 
@@ -52,4 +54,7 @@ async def run(
 
     # run the agent with limits
     with apply_limits(limits):
-        return await agent(state, **agent_kwargs)
+        # run the agent
+        agent_name = registry_unqualified_name(agent)
+        async with span(name=agent_name, type="agent"):
+            return await agent(state, **agent_kwargs)
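Note: `run()` now records each top-level agent execution as an `agent`-type span named after the agent's unqualified registry name. A hedged sketch, where `my_agent` is an illustrative `@agent` function:

    from inspect_ai.agent import run

    # executes inside: async with span(name="my_agent", type="agent")
    state = await run(my_agent(), "Say hello.")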
inspect_ai/agent/_types.py
CHANGED
@@ -96,3 +96,12 @@ class AgentSubmit(NamedTuple):
 
     The tool should return the `answer` provided to it for scoring.
     """
+
+    answer_only: bool = False
+    """Set the completion to only the answer provided by the submit tool.
+
+    By default, the answer is appended (with `answer_delimiter`) to whatever
+    other content the model generated along with the call to `submit()`."""
+
+    answer_delimiter: str = "\n\n"
+    """Delimiter used when appending the submit tool answer to other content the model generated along with the call to `submit()`."""
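Note: these two fields control how `react()` folds the submitted answer into `state.output.completion` (see the `_react.py` hunk above). A hedged usage sketch, assuming `AgentSubmit` and `react` are re-exported from `inspect_ai.agent`:

    from inspect_ai.agent import AgentSubmit, react

    # score only the submitted answer, discarding any commentary the
    # model generated alongside its call to submit()
    solver = react(
        name="solver",
        description="Solves the task",
        submit=AgentSubmit(answer_only=True),
    )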
inspect_ai/analysis/__init__.py
ADDED
File without changes

inspect_ai/analysis/beta/__init__.py
ADDED
@@ -0,0 +1,57 @@
+from ._dataframe.columns import (
+    Column,
+    ColumnError,
+    ColumnErrors,
+    ColumnType,
+)
+from ._dataframe.evals.columns import (
+    EvalColumn,
+    EvalColumns,
+    EvalConfig,
+    EvalInfo,
+    EvalModel,
+    EvalResults,
+    EvalScores,
+    EvalTask,
+)
+from ._dataframe.evals.table import evals_df
+from ._dataframe.events.columns import EventColumn
+from ._dataframe.events.table import events_df
+from ._dataframe.messages.columns import (
+    MessageColumn,
+    MessageColumns,
+    MessageContent,
+    MessageToolCalls,
+)
+from ._dataframe.messages.table import MessageFilter, messages_df
+from ._dataframe.samples.columns import SampleColumn, SampleMessages, SampleSummary
+from ._dataframe.samples.table import samples_df
+
+__all__ = [
+    "evals_df",
+    "EvalColumn",
+    "EvalColumns",
+    "EvalInfo",
+    "EvalTask",
+    "EvalModel",
+    "EvalColumns",
+    "EvalConfig",
+    "EvalResults",
+    "EvalScores",
+    "samples_df",
+    "SampleColumn",
+    "SampleSummary",
+    "SampleMessages",
+    "messages_df",
+    "MessageColumn",
+    "MessageContent",
+    "MessageToolCalls",
+    "MessageColumns",
+    "MessageFilter",
+    "events_df",
+    "EventColumn",
+    "Column",
+    "ColumnType",
+    "ColumnError",
+    "ColumnErrors",
+]
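Note: `inspect_ai.analysis.beta` is the entry point for the new dataframe API in this release. A minimal example, assuming pandas is installed (checked by `verify_prerequisites()`) and a local `logs` directory:

    from inspect_ai.analysis.beta import EvalColumns, evals_df

    # one row per eval log, using the default column set
    df = evals_df("logs", columns=EvalColumns)
    print(df[["eval_id", "model", "status"]].head())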
inspect_ai/analysis/beta/_dataframe/__init__.py
ADDED
File without changes

inspect_ai/analysis/beta/_dataframe/columns.py
ADDED
@@ -0,0 +1,145 @@
+import abc
+from dataclasses import KW_ONLY, dataclass
+from datetime import date, datetime, time
+from typing import Any, Callable, Mapping, Type, TypeAlias
+
+from jsonpath_ng import JSONPath  # type: ignore
+from jsonpath_ng.ext import parse  # type: ignore
+from pydantic import JsonValue
+
+from .validate import jsonpath_in_schema
+
+ColumnType: TypeAlias = int | float | bool | str | date | time | datetime | None
+"""Valid types for columns.
+
+Values of `list` and `dict` are converted into column values as JSON `str`.
+"""
+
+
+class Column(abc.ABC):
+    """
+    Specification for importing a column into a dataframe.
+
+    Extract columns from an `EvalLog` path either using [JSONPath](https://github.com/h2non/jsonpath-ng) expressions
+    or a function that takes `EvalLog` and returns a value.
+
+    By default, columns are not required; pass `required=True` to make them required. Non-required
+    columns are extracted as `None`; provide a `default` to yield an alternate value.
+
+    The `type` option serves as both a validation check and a directive to attempt to coerce the
+    data into the specified `type`. Coercion from `str` to other types is done after interpreting
+    the string using YAML (e.g. `"true"` -> `True`).
+
+    The `value` function provides an additional hook for transformation of the value read
+    from the log before it is realized as a column (e.g. list to a comma-separated string).
+
+    The `root` option indicates which root eval log context the columns select from.
+    """
+
+    def __init__(
+        self,
+        name: str,
+        *,
+        path: str | JSONPath | None,
+        required: bool = False,
+        default: JsonValue | None = None,
+        type: Type[ColumnType] | None = None,
+        value: Callable[[JsonValue], JsonValue] | None = None,
+    ) -> None:
+        self._name = name
+        self._path: str | JSONPath | None = path
+        self._required = required
+        self._default = default
+        self._type = type
+        self._value = value
+        self._validated: bool | None = None
+
+    @property
+    def name(self) -> str:
+        """Column name."""
+        return self._name
+
+    @property
+    def path(self) -> JSONPath | None:
+        """Path to column in `EvalLog`"""
+        if isinstance(self._path, str):
+            self._path = parse(self._path)
+        return self._path
+
+    @property
+    def required(self) -> bool:
+        """Is the column required? (error is raised if required columns aren't found)."""
+        return self._required
+
+    @property
+    def default(self) -> JsonValue | None:
+        """Default value for column when it is read from the log as `None`."""
+        return self._default
+
+    @property
+    def type(self) -> Type[ColumnType] | None:
+        """Column type (import will attempt to coerce to the specified type)."""
+        return self._type
+
+    def value(self, x: JsonValue) -> JsonValue:
+        """Convert extracted value into a column value (defaults to identity function).
+
+        Params:
+            x: Value to convert.
+
+        Returns:
+            Converted value.
+        """
+        if self._value:
+            return self._value(x)
+        else:
+            return x
+
+    def validate_path(self) -> bool:
+        if self.path is not None:
+            if self._validated is None:
+                schema = self.path_schema()
+                self._validated = (
+                    jsonpath_in_schema(self.path, schema) if schema else True
+                )
+            return self._validated
+        else:
+            return True
+
+    @abc.abstractmethod
+    def path_schema(self) -> Mapping[str, Any] | None: ...
+
+
+@dataclass
+class ColumnError:
+    """Error which occurred parsing a column."""
+
+    column: str
+    """Target column name."""
+
+    _: KW_ONLY
+
+    path: str | None
+    """Path to select column value."""
+
+    message: str
+    """Error message."""
+
+    def __str__(self) -> str:
+        msg = f"Error reading column '{self.column}'"
+        if self.path:
+            msg = f"{msg} from path '{self.path}'"
+        return f"{msg}: {self.message}"
+
+
+class ColumnErrors(dict[str, list[ColumnError]]):
+    """Dictionary of column errors keyed by log file."""
+
+    def __str__(self) -> str:
+        lines: list[str] = [""]
+        for file, errors in self.items():
+            lines.append(file)
+            for error in errors:
+                lines.append(f" - {error}")
+            lines.append("")
+        return "\n".join(lines)
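Note: `Column` carries the extraction hooks described in its docstring (`required`, `default`, `type`, `value`) while concrete subclasses supply the schema. A hedged sketch of the `default`/`value` hooks using `EvalColumn` (defined in the next file); the lambda is illustrative, the built-in `EvalInfo` columns use `list_as_str` for this:

    from inspect_ai.analysis.beta import EvalColumn

    # JSONPath extraction with a default plus a value-transform hook
    tags_column = EvalColumn(
        "tags",
        path="eval.tags",
        default="",
        value=lambda x: ",".join(x) if isinstance(x, list) else x,
    )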
inspect_ai/analysis/beta/_dataframe/evals/__init__.py
ADDED
File without changes

inspect_ai/analysis/beta/_dataframe/evals/columns.py
ADDED
@@ -0,0 +1,132 @@
+from datetime import datetime
+from typing import Any, Callable, Mapping, Type
+
+from jsonpath_ng import JSONPath  # type: ignore
+from pydantic import JsonValue
+from typing_extensions import override
+
+from inspect_ai.log._log import EvalLog
+
+from ..columns import Column, ColumnType
+from ..extract import list_as_str
+from ..validate import resolved_schema
+from .extract import eval_log_location, eval_log_scores_dict
+
+
+class EvalColumn(Column):
+    """Column which maps to `EvalLog`."""
+
+    def __init__(
+        self,
+        name: str,
+        *,
+        path: str | JSONPath | Callable[[EvalLog], JsonValue],
+        required: bool = False,
+        default: JsonValue | None = None,
+        type: Type[ColumnType] | None = None,
+        value: Callable[[JsonValue], JsonValue] | None = None,
+    ) -> None:
+        super().__init__(
+            name=name,
+            path=path if not callable(path) else None,
+            required=required,
+            default=default,
+            type=type,
+            value=value,
+        )
+        self._extract_eval = path if callable(path) else None
+
+    @override
+    def path_schema(self) -> Mapping[str, Any]:
+        return self.schema
+
+    schema = resolved_schema(EvalLog)
+
+
+EvalId: list[Column] = [
+    EvalColumn("eval_id", path="eval.eval_id", required=True),
+]
+"""Eval id column."""
+
+EvalInfo: list[Column] = [
+    EvalColumn("run_id", path="eval.run_id", required=True),
+    EvalColumn("task_id", path="eval.task_id", required=True),
+    EvalColumn("log", path=eval_log_location),
+    EvalColumn("created", path="eval.created", type=datetime, required=True),
+    EvalColumn("tags", path="eval.tags", default="", value=list_as_str),
+    EvalColumn("git_origin", path="eval.revision.origin"),
+    EvalColumn("git_commit", path="eval.revision.commit"),
+    EvalColumn("packages", path="eval.packages"),
+    EvalColumn("metadata", path="eval.metadata"),
+]
+"""Eval basic information columns."""
+
+EvalTask: list[Column] = [
+    EvalColumn("task_name", path="eval.task", required=True),
+    EvalColumn("task_version", path="eval.task_version", required=True),
+    EvalColumn("task_file", path="eval.task_file"),
+    EvalColumn("task_attribs", path="eval.task_attribs"),
+    EvalColumn("task_arg_*", path="eval.task_args"),
+    EvalColumn("solver", path="eval.solver"),
+    EvalColumn("solver_args", path="eval.solver_args"),
+    EvalColumn("sandbox_type", path="eval.sandbox.type"),
+    EvalColumn("sandbox_config", path="eval.sandbox.config"),
+]
+"""Eval task configuration columns."""
+
+EvalModel: list[Column] = [
+    EvalColumn("model", path="eval.model", required=True),
+    EvalColumn("model_base_url", path="eval.model_base_url"),
+    EvalColumn("model_args", path="eval.model_base_url"),
+    EvalColumn("model_generate_config", path="eval.model_generate_config"),
+    EvalColumn("model_roles", path="eval.model_roles"),
+]
+"""Eval model columns."""
+
+EvalDataset: list[Column] = [
+    EvalColumn("dataset_name", path="eval.dataset.name"),
+    EvalColumn("dataset_location", path="eval.dataset.location"),
+    EvalColumn("dataset_samples", path="eval.dataset.samples"),
+    EvalColumn("dataset_sample_ids", path="eval.dataset.sample_ids"),
+    EvalColumn("dataset_shuffled", path="eval.dataset.shuffled"),
+]
+"""Eval dataset columns."""
+
+EvalConfig: list[Column] = [
+    EvalColumn("epochs", path="eval.config.epochs"),
+    EvalColumn("epochs_reducer", path="eval.config.epochs_reducer"),
+    EvalColumn("approval", path="eval.config.approval"),
+    EvalColumn("message_limit", path="eval.config.message_limit"),
+    EvalColumn("token_limit", path="eval.config.token_limit"),
+    EvalColumn("time_limit", path="eval.config.time_limit"),
+    EvalColumn("working_limit", path="eval.config.working_limit"),
+]
+"""Eval configuration columns."""
+
+EvalResults: list[Column] = [
+    EvalColumn("status", path="status", required=True),
+    EvalColumn("error_message", path="error.message"),
+    EvalColumn("error_traceback", path="error.traceback"),
+    EvalColumn("total_samples", path="results.total_samples"),
+    EvalColumn("completed_samples", path="results.completed_samples"),
+    EvalColumn("score_headline_name", path="results.scores[0].scorer"),
+    EvalColumn("score_headline_metric", path="results.scores[0].metrics.*.name"),
+    EvalColumn("score_headline_value", path="results.scores[0].metrics.*.value"),
+]
+"""Eval results columns."""
+
+EvalScores: list[Column] = [
+    EvalColumn("score_*_*", path=eval_log_scores_dict),
+]
+"""Eval scores (one score/metric per column)."""
+
+EvalColumns: list[Column] = (
+    EvalInfo
+    + EvalTask
+    + EvalModel
+    + EvalDataset
+    + EvalConfig
+    + EvalResults
+    + EvalScores
+)
+"""Default columns to import for `evals_df()`."""
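Note: because the column groups are plain lists, they compose with `+` to build custom import sets; `evals_df()` adds `eval_id` automatically via `ensure_eval_id()`. For example:

    from inspect_ai.analysis.beta import EvalInfo, EvalModel, EvalResults, evals_df

    # narrower frame: identity/info, model, and results columns only
    df = evals_df("logs", columns=EvalInfo + EvalModel + EvalResults)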
inspect_ai/analysis/beta/_dataframe/evals/extract.py
ADDED
@@ -0,0 +1,23 @@
+from inspect_ai._util.path import native_path
+from inspect_ai.log._log import EvalLog
+
+
+def eval_log_location(log: EvalLog) -> str:
+    return native_path(log.location)
+
+
+def eval_log_scores_dict(
+    log: EvalLog,
+) -> list[dict[str, dict[str, int | float]]] | None:
+    if log.results is not None:
+        metrics = [
+            {
+                score.name: {
+                    metric.name: metric.value for metric in score.metrics.values()
+                }
+            }
+            for score in log.results.scores
+        ]
+        return metrics
+    else:
+        return None
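Note: `eval_log_scores_dict()` feeds the wildcard `score_*_*` column: each scorer maps to a dict of its metrics, and the record importer fans the nesting out into one column per scorer/metric pair. An illustrative shape (values invented for the example):

    # a log with a single "choice" scorer reporting accuracy/stderr:
    # [{"choice": {"accuracy": 0.82, "stderr": 0.03}}]
    # ...expands into columns score_choice_accuracy and score_choice_stderr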
inspect_ai/analysis/beta/_dataframe/evals/table.py
ADDED
@@ -0,0 +1,140 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Literal, overload
+
+from inspect_ai._display import display
+from inspect_ai._util.path import pretty_path
+from inspect_ai.log._file import (
+    read_eval_log,
+)
+
+from ..columns import Column, ColumnErrors, ColumnType
+from ..record import import_record, resolve_duplicate_columns
+from ..util import (
+    LogPaths,
+    add_unreferenced_columns,
+    records_to_pandas,
+    resolve_columns,
+    resolve_logs,
+    verify_prerequisites,
+)
+from .columns import EvalColumns, EvalId
+
+if TYPE_CHECKING:
+    import pandas as pd
+
+EVAL_ID = "eval_id"
+EVAL_SUFFIX = "_eval"
+
+
+@overload
+def evals_df(
+    logs: LogPaths,
+    columns: list[Column] = EvalColumns,
+    recursive: bool = True,
+    reverse: bool = False,
+    strict: Literal[True] = True,
+) -> "pd.DataFrame": ...
+
+
+@overload
+def evals_df(
+    logs: LogPaths,
+    columns: list[Column] = EvalColumns,
+    recursive: bool = True,
+    reverse: bool = False,
+    strict: Literal[False] = False,
+) -> tuple["pd.DataFrame", ColumnErrors]: ...
+
+
+def evals_df(
+    logs: LogPaths,
+    columns: list[Column] = EvalColumns,
+    recursive: bool = True,
+    reverse: bool = False,
+    strict: bool = True,
+) -> "pd.DataFrame" | tuple["pd.DataFrame", ColumnErrors]:
+    """Read a dataframe containing evals.
+
+    Args:
+        logs: One or more paths to log files or log directories.
+        columns: Specification for what columns to read from log files.
+        recursive: Include recursive contents of directories (defaults to `True`)
+        reverse: Reverse the order of the dataframe (by default, items
+            are ordered from oldest to newest).
+        strict: Raise import errors immediately. Defaults to `True`.
+            If `False` then a tuple of `DataFrame` and errors is returned.
+
+    Returns:
+        For `strict`, a Pandas `DataFrame` with information for the specified logs.
+        For `strict=False`, a tuple of Pandas `DataFrame` and a dictionary of errors
+        encountered (by log file) during import.
+    """
+    verify_prerequisites()
+
+    # resolve logs
+    log_paths = resolve_logs(logs, recursive=recursive, reverse=reverse)
+
+    # resolve duplicate columns
+    columns = resolve_duplicate_columns(columns)
+
+    # accumulate errors for strict=False
+    all_errors = ColumnErrors()
+
+    # ensure eval_id
+    ensure_eval_id(columns)
+
+    # read logs
+    records: list[dict[str, ColumnType]] = []
+    with display().progress(total=len(log_paths)) as p:
+        for log_path in log_paths:
+            log = read_eval_log(log_path, header_only=True)
+            if strict:
+                record = import_record(log, columns, strict=True)
+            else:
+                record, errors = import_record(log, columns, strict=False)
+                all_errors[pretty_path(log_path)] = errors
+            records.append(record)
+
+            p.update()
+
+    # return table (+errors if strict=False)
+    evals_table = records_to_pandas(records)
+    evals_table = reorder_evals_df_columns(evals_table, columns)
+
+    if strict:
+        return evals_table
+    else:
+        return evals_table, all_errors
+
+
+def ensure_eval_id(columns: list[Column]) -> None:
+    if not any([column.name == EVAL_ID for column in columns]):
+        columns.extend(EvalId)
+
+
+def reorder_evals_df_columns(
+    df: "pd.DataFrame", eval_columns: list[Column]
+) -> "pd.DataFrame":
+    actual_columns = list(df.columns)
+    ordered_columns: list[str] = []
+
+    # eval_id first
+    if EVAL_ID in actual_columns:
+        ordered_columns.append(EVAL_ID)
+
+    # eval columns
+    for col in eval_columns:
+        col_pattern = col.name
+        if col_pattern == EVAL_ID:
+            continue  # Already handled
+
+        ordered_columns.extend(
+            resolve_columns(col_pattern, EVAL_SUFFIX, actual_columns, ordered_columns)
+        )
+
+    # add any unreferenced columns
+    ordered_columns = add_unreferenced_columns(actual_columns, ordered_columns)
+
+    # reorder the DataFrame
+    return df[ordered_columns]
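Note: the `strict` overloads give two error-handling modes. For example:

    from inspect_ai.analysis.beta import evals_df

    # strict=True (default): the first import error raises immediately
    df = evals_df("logs")

    # strict=False: collect errors per log file instead of raising
    df, errors = evals_df("logs", strict=False)
    if errors:
        print(errors)  # ColumnErrors renders a readable per-file report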
inspect_ai/analysis/beta/_dataframe/events/__init__.py
ADDED
File without changes

inspect_ai/analysis/beta/_dataframe/events/columns.py
ADDED
@@ -0,0 +1,37 @@
+from typing import Any, Callable, Mapping, Type
+
+from jsonpath_ng import JSONPath  # type: ignore
+from pydantic import JsonValue
+from typing_extensions import override
+
+from inspect_ai.log._transcript import Event
+
+from ..columns import Column, ColumnType
+
+
+class EventColumn(Column):
+    """Column which maps to `Event`."""
+
+    def __init__(
+        self,
+        name: str,
+        *,
+        path: str | JSONPath | Callable[[Event], JsonValue],
+        required: bool = False,
+        default: JsonValue | None = None,
+        type: Type[ColumnType] | None = None,
+        value: Callable[[JsonValue], JsonValue] | None = None,
+    ) -> None:
+        super().__init__(
+            name=name,
+            path=path if not callable(path) else None,
+            required=required,
+            default=default,
+            type=type,
+            value=value,
+        )
+        self._extract_event = path if callable(path) else None
+
+    @override
+    def path_schema(self) -> Mapping[str, Any] | None:
+        return None