inspect-ai 0.3.93__py3-none-any.whl → 0.3.95__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_display/textual/widgets/samples.py +3 -3
- inspect_ai/_display/textual/widgets/transcript.py +3 -29
- inspect_ai/_eval/loader.py +1 -1
- inspect_ai/_eval/task/run.py +21 -12
- inspect_ai/_util/answer.py +26 -0
- inspect_ai/_util/constants.py +0 -1
- inspect_ai/_util/exception.py +4 -0
- inspect_ai/_util/hash.py +39 -0
- inspect_ai/_util/local_server.py +51 -21
- inspect_ai/_util/path.py +22 -0
- inspect_ai/_util/trace.py +1 -1
- inspect_ai/_util/working.py +4 -0
- inspect_ai/_view/www/dist/assets/index.css +23 -22
- inspect_ai/_view/www/dist/assets/index.js +517 -204
- inspect_ai/_view/www/log-schema.json +375 -0
- inspect_ai/_view/www/package.json +1 -1
- inspect_ai/_view/www/src/@types/log.d.ts +90 -12
- inspect_ai/_view/www/src/app/log-view/navbar/SecondaryBar.tsx +2 -2
- inspect_ai/_view/www/src/app/log-view/tabs/SamplesTab.tsx +1 -4
- inspect_ai/_view/www/src/app/samples/SamplesTools.tsx +3 -13
- inspect_ai/_view/www/src/app/samples/sample-tools/SelectScorer.tsx +45 -48
- inspect_ai/_view/www/src/app/samples/sample-tools/filters.ts +16 -15
- inspect_ai/_view/www/src/app/samples/sample-tools/sample-filter/SampleFilter.tsx +47 -75
- inspect_ai/_view/www/src/app/samples/sample-tools/sample-filter/completions.ts +9 -9
- inspect_ai/_view/www/src/app/samples/transcript/SandboxEventView.module.css +2 -1
- inspect_ai/_view/www/src/app/samples/transcript/SpanEventView.tsx +174 -0
- inspect_ai/_view/www/src/app/samples/transcript/ToolEventView.tsx +8 -8
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptView.tsx +12 -2
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualListComponent.module.css +1 -1
- inspect_ai/_view/www/src/app/samples/transcript/event/EventPanel.tsx +0 -3
- inspect_ai/_view/www/src/app/samples/transcript/transform/fixups.ts +87 -25
- inspect_ai/_view/www/src/app/samples/transcript/transform/treeify.ts +229 -17
- inspect_ai/_view/www/src/app/samples/transcript/transform/utils.ts +11 -0
- inspect_ai/_view/www/src/app/samples/transcript/types.ts +5 -1
- inspect_ai/_view/www/src/app/types.ts +12 -2
- inspect_ai/_view/www/src/components/ExpandablePanel.module.css +1 -1
- inspect_ai/_view/www/src/components/ExpandablePanel.tsx +5 -5
- inspect_ai/_view/www/src/state/hooks.ts +19 -3
- inspect_ai/_view/www/src/state/logSlice.ts +23 -5
- inspect_ai/_view/www/yarn.lock +9 -9
- inspect_ai/agent/_as_solver.py +3 -1
- inspect_ai/agent/_as_tool.py +6 -4
- inspect_ai/agent/_bridge/patch.py +1 -3
- inspect_ai/agent/_handoff.py +5 -1
- inspect_ai/agent/_react.py +4 -3
- inspect_ai/agent/_run.py +6 -1
- inspect_ai/agent/_types.py +9 -0
- inspect_ai/analysis/__init__.py +0 -0
- inspect_ai/analysis/beta/__init__.py +57 -0
- inspect_ai/analysis/beta/_dataframe/__init__.py +0 -0
- inspect_ai/analysis/beta/_dataframe/columns.py +145 -0
- inspect_ai/analysis/beta/_dataframe/evals/__init__.py +0 -0
- inspect_ai/analysis/beta/_dataframe/evals/columns.py +132 -0
- inspect_ai/analysis/beta/_dataframe/evals/extract.py +23 -0
- inspect_ai/analysis/beta/_dataframe/evals/table.py +140 -0
- inspect_ai/analysis/beta/_dataframe/events/__init__.py +0 -0
- inspect_ai/analysis/beta/_dataframe/events/columns.py +37 -0
- inspect_ai/analysis/beta/_dataframe/events/table.py +14 -0
- inspect_ai/analysis/beta/_dataframe/extract.py +54 -0
- inspect_ai/analysis/beta/_dataframe/messages/__init__.py +0 -0
- inspect_ai/analysis/beta/_dataframe/messages/columns.py +60 -0
- inspect_ai/analysis/beta/_dataframe/messages/extract.py +21 -0
- inspect_ai/analysis/beta/_dataframe/messages/table.py +87 -0
- inspect_ai/analysis/beta/_dataframe/record.py +377 -0
- inspect_ai/analysis/beta/_dataframe/samples/__init__.py +0 -0
- inspect_ai/analysis/beta/_dataframe/samples/columns.py +73 -0
- inspect_ai/analysis/beta/_dataframe/samples/extract.py +82 -0
- inspect_ai/analysis/beta/_dataframe/samples/table.py +329 -0
- inspect_ai/analysis/beta/_dataframe/util.py +157 -0
- inspect_ai/analysis/beta/_dataframe/validate.py +171 -0
- inspect_ai/dataset/_dataset.py +6 -3
- inspect_ai/log/__init__.py +10 -0
- inspect_ai/log/_convert.py +4 -9
- inspect_ai/log/_file.py +1 -1
- inspect_ai/log/_log.py +21 -1
- inspect_ai/log/_samples.py +14 -17
- inspect_ai/log/_transcript.py +77 -35
- inspect_ai/log/_tree.py +118 -0
- inspect_ai/model/_call_tools.py +44 -35
- inspect_ai/model/_model.py +51 -44
- inspect_ai/model/_openai_responses.py +17 -18
- inspect_ai/model/_providers/anthropic.py +30 -5
- inspect_ai/model/_providers/hf.py +27 -1
- inspect_ai/model/_providers/providers.py +1 -1
- inspect_ai/model/_providers/sglang.py +8 -2
- inspect_ai/model/_providers/vllm.py +6 -2
- inspect_ai/scorer/_choice.py +1 -2
- inspect_ai/solver/_chain.py +1 -1
- inspect_ai/solver/_fork.py +1 -1
- inspect_ai/solver/_multiple_choice.py +9 -23
- inspect_ai/solver/_plan.py +2 -2
- inspect_ai/solver/_task_state.py +7 -3
- inspect_ai/solver/_transcript.py +6 -7
- inspect_ai/tool/_mcp/_context.py +3 -5
- inspect_ai/tool/_mcp/_mcp.py +6 -5
- inspect_ai/tool/_mcp/server.py +1 -1
- inspect_ai/tool/_tools/_execute.py +4 -1
- inspect_ai/tool/_tools/_think.py +1 -1
- inspect_ai/tool/_tools/_web_search/__init__.py +3 -0
- inspect_ai/tool/_tools/{_web_search.py → _web_search/_google.py} +56 -103
- inspect_ai/tool/_tools/_web_search/_tavily.py +77 -0
- inspect_ai/tool/_tools/_web_search/_web_search.py +85 -0
- inspect_ai/util/__init__.py +4 -0
- inspect_ai/util/_anyio.py +11 -0
- inspect_ai/util/_collect.py +50 -0
- inspect_ai/util/_sandbox/events.py +3 -2
- inspect_ai/util/_span.py +58 -0
- inspect_ai/util/_subtask.py +27 -42
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/METADATA +8 -1
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/RECORD +114 -82
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/WHEEL +1 -1
- inspect_ai/_display/core/group.py +0 -79
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/licenses/LICENSE +0 -0
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/top_level.txt +0 -0
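The most substantial addition in this release is the new `inspect_ai/analysis/beta` package, which builds data frames from eval logs. A minimal sketch of how that surface is likely intended to be used, assuming `evals_df` and `samples_df` are exported from `inspect_ai.analysis.beta` (names inferred from the module layout above, not confirmed by this diff; the `logs` directory is hypothetical):

    # hypothetical usage of the new dataframe API (exports inferred, not confirmed)
    from inspect_ai.analysis.beta import evals_df, samples_df

    # one row per eval log found under the directory
    evals = evals_df("logs")

    # one row per sample across those logs
    samples = samples_df("logs")
    print(samples.head())

Selected per-file diffs follow.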
inspect_ai/analysis/beta/_dataframe/validate.py
ADDED
@@ -0,0 +1,171 @@
+from __future__ import annotations
+
+from logging import getLogger
+from typing import Any, Iterator, Mapping, Type
+
+import jsonref  # type: ignore
+from jsonpath_ng import Fields, Index, JSONPath, Slice, Where, WhereNot  # type: ignore
+from jsonpath_ng.ext.filter import Filter  # type: ignore
+from pydantic import BaseModel
+
+logger = getLogger(__name__)
+
+Schema = Mapping[str, Any]
+
+
+def resolved_schema(model: Type[BaseModel]) -> Schema:
+    schema_dict = model.model_json_schema()
+    base = "file:///memory/inspect_schema.json"
+    schema: Schema = jsonref.replace_refs(
+        schema_dict, base_uri=base, jsonschema=True, proxies=False
+    )
+    return schema
+
+
+def jsonpath_in_schema(expr: JSONPath, schema: Schema) -> bool:
+    # don't validate unsupported constructs
+    if find_unsupported(expr):
+        return True
+
+    def descend(sch: Schema, tok: str | int | None) -> list[Schema]:
+        # First, branch through anyOf/oneOf/allOf
+        outs: list[Schema] = []
+        for branch in _expand_union(sch):
+            outs.extend(descend_concrete(branch, tok))
+        return outs
+
+    def descend_concrete(sch: Schema, tok: str | int | None) -> list[Schema]:
+        # totally open object – accept any child
+        if sch == {}:
+            return [{}]  # stay alive, accept any key
+
+        outs: list[Schema] = []
+
+        def open_dict(node: Schema) -> None:
+            """Append the schema that governs unknown keys.
+
+            - None / missing -> open object -> {}
+            - True -> open object -> {}
+            - Mapping -> that mapping (could be {} or a real subschema)
+            - False -> closed object -> (do nothing)
+            """
+            if "additionalProperties" not in node:
+                if not node.get("properties"):
+                    outs.append({})
+            else:
+                ap = node["additionalProperties"]
+                if ap is True:
+                    outs.append({})
+                elif isinstance(ap, Mapping):  # {} or {...}
+                    outs.append(ap)
+                # ap is False -> closed dict -> ignore
+
+        # Wildcard -----------------------------------------------------------
+        if tok is None:
+            if "properties" in sch:
+                outs.extend(sch["properties"].values())
+            if "object" in _types(sch):
+                open_dict(sch)
+            if "array" in _types(sch) and "items" in sch:
+                outs.extend(_normalize_items(sch["items"]))
+            return outs
+
+        # Property access ----------------------------------------------------
+        if isinstance(tok, str):
+            if "properties" in sch and tok in sch["properties"]:
+                outs.append(sch["properties"][tok])
+            elif "additionalProperties" in sch:  # PRESENCE, not truthiness
+                open_dict(sch)
+            elif "object" in _types(sch):
+                open_dict(sch)
+
+        # Array index --------------------------------------------------------
+        else:  # tok is int or None from an Index node
+            if "array" in _types(sch) and "items" in sch:
+                outs.extend(_normalize_items(sch["items"], index=tok))
+
+        return outs
+
+    def _types(sch: Schema) -> set[str]:
+        t = sch.get("type")
+        return set(t) if isinstance(t, list) else {t} if t else set()
+
+    def _normalize_items(items: Any, index: int | None = None) -> list[Schema]:
+        if isinstance(items, list):
+            if index is None:  # wildcard/slice
+                return items
+            if 0 <= index < len(items):
+                return [items[index]]
+            return []
+        if isinstance(items, Mapping):
+            return [items]
+        return []
+
+    states = [schema]
+    for tok in iter_tokens(expr):
+        next_states: list[Schema] = []
+        for st in states:
+            next_states.extend(descend(st, tok))
+        if not next_states:  # nothing matched this segment
+            return False
+        states = next_states
+    return True  # every segment found at least one schema
+
+
+def iter_tokens(node: JSONPath) -> Iterator[str | int | None]:
+    """Linearise a jsonpath-ng AST into a stream of tokens we care about."""
+    if hasattr(node, "left"):  # Child, Descendants, etc.
+        yield from iter_tokens(node.left)
+        yield from iter_tokens(node.right)
+    elif isinstance(node, Fields):
+        yield from node.fields  # e.g. ["foo"]
+    elif isinstance(node, Index):
+        yield node.index  # 0 / -1 / None for wildcard
+    elif isinstance(node, Slice):
+        yield None  # treat any slice as wildcard
+
+
+COMBINATORS = ("anyOf", "oneOf", "allOf")
+
+
+def _expand_union(sch: Schema) -> list[Schema]:
+    """Return sch itself or the list of subschemas if it is a combinator."""
+    for key in COMBINATORS:
+        if key in sch:
+            subs: list[Schema] = []
+            for sub in sch[key]:
+                # a sub-schema might itself be an anyOf/oneOf/allOf
+                subs.extend(_expand_union(sub))
+            return subs
+    return [sch]
+
+
+UNSUPPORTED: tuple[type[JSONPath], ...] = (
+    Filter,  # [?foo > 0]
+    Where,  # .foo[(@.bar < 42)]
+    WhereNot,
+    Slice,  # [1:5] (wildcard “[*]” is Index/None, not Slice)
+)
+
+
+def find_unsupported(node: JSONPath) -> list[type[JSONPath]]:
+    """Return a list of node types present in `node` that we do not validate."""
+    bad: list[type[JSONPath]] = []
+    stack: list[JSONPath] = [node]
+    while stack:
+        n = stack.pop()
+        if isinstance(n, UNSUPPORTED):
+            bad.append(type(n))
+        # Drill into children (jsonpath-ng uses .left / .right / .child attributes)
+        for attr in ("left", "right", "child", "expression"):
+            stack.extend(
+                [getattr(n, attr)]
+                if hasattr(n, attr) and isinstance(getattr(n, attr), JSONPath)
+                else []
+            )
+        # handle containers like Fields(fields=[...]) and Index(index=[...])
+        if hasattr(n, "__dict__"):
+            for v in n.__dict__.values():
+                if isinstance(v, list):
+                    stack.extend(x for x in v if isinstance(x, JSONPath))
+    return bad
inspect_ai/dataset/_dataset.py
CHANGED
@@ -16,6 +16,7 @@ from typing import (
 from pydantic import BaseModel, Field, ValidationError
 from typing_extensions import override

+from inspect_ai._util.answer import answer_character, answer_index
 from inspect_ai.model import ChatMessage
 from inspect_ai.util import SandboxEnvironmentSpec, SandboxEnvironmentType
 from inspect_ai.util._sandbox.environment import resolve_sandbox_environment
@@ -328,7 +329,9 @@ class MemoryDataset(Dataset):
         shuffled_choices = [sample.choices[i] for i in positions]

         # Map of original position / target letter
-        position_map = {
+        position_map = {
+            i: answer_character(new_i) for new_i, i in enumerate(positions)
+        }

         # Update to the shuffled choices and target
         sample.choices = shuffled_choices
@@ -338,9 +341,9 @@
         self, target: str | list[str], position_map: dict[int, str]
     ) -> str | list[str]:
         if isinstance(target, list):
-            return [position_map[
+            return [position_map[answer_index(t)] for t in target]
         else:
-            return position_map[
+            return position_map[answer_index(target)]

     @override
     def sort(
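For context on the remapping above: `answer_character` and `answer_index` come from the new `inspect_ai/_util/answer.py` listed earlier. A self-contained sketch of the assumed contract (index 0 maps to "A" and back), showing how a shuffled target gets relabeled:

    # illustrative only: assumes answer_character(i) == chr(ord("A") + i) and
    # answer_index(c) == ord(c) - ord("A") for single-letter answers
    positions = [2, 0, 1]  # shuffled order: original choice 2 now appears first
    position_map = {i: chr(ord("A") + new_i) for new_i, i in enumerate(positions)}
    # -> {2: "A", 0: "B", 1: "C"}

    old_target = "C"  # pointed at original choice index 2
    new_target = position_map[ord(old_target) - ord("A")]
    assert new_target == "A"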
inspect_ai/log/__init__.py
CHANGED
@@ -48,6 +48,8 @@ from ._transcript import (
     SampleLimitEvent,
     SandboxEvent,
     ScoreEvent,
+    SpanBeginEvent,
+    SpanEndEvent,
     StateEvent,
     StepEvent,
     StoreEvent,
@@ -56,6 +58,7 @@ from ._transcript import (
     Transcript,
     transcript,
 )
+from ._tree import EventNode, EventTree, SpanNode, event_sequence, event_tree

 __all__ = [
     "EvalConfig",
@@ -92,6 +95,8 @@ __all__ = [
     "SampleLimitEvent",
     "SandboxEvent",
     "ScoreEvent",
+    "SpanBeginEvent",
+    "SpanEndEvent",
     "StateEvent",
     "StepEvent",
     "StoreEvent",
@@ -111,4 +116,9 @@ __all__ = [
     "write_log_dir_manifest",
     "retryable_eval_logs",
     "bundle_log_dir",
+    "event_tree",
+    "event_sequence",
+    "EventTree",
+    "EventNode",
+    "SpanNode",
 ]
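The new `_tree` exports suggest that a flat event list can now be folded into a span hierarchy. A rough sketch, assuming `event_tree()` accepts a sequence of events and returns the root nodes (signature inferred from the export names, not confirmed by this diff; the log path is hypothetical):

    # hypothetical: fold a sample's flat event list into a span tree
    from inspect_ai.log import event_tree, read_eval_log

    log = read_eval_log("logs/example.eval")  # hypothetical path
    for sample in log.samples or []:
        for node in event_tree(sample.events):
            print(type(node).__name__)  # e.g. SpanNode / EventNode (assumed node types)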
inspect_ai/log/_convert.py
CHANGED
@@ -2,7 +2,7 @@ import os
 from typing import Literal

 from inspect_ai._util.error import PrerequisiteError
-from inspect_ai._util.file import
+from inspect_ai._util.file import exists, filesystem
 from inspect_ai.log._file import (
     log_files_from_ls,
     read_eval_log,
@@ -66,14 +66,9 @@ def convert_eval_logs(
                 "Output file {output_file} already exists (use --overwrite to overwrite existing files)"
             )

-        #
-
-
-
-        # otherwise do a full read/write
-        else:
-            log = read_eval_log(input_file)
-            write_eval_log(log, output_file)
+        # do a full read/write (normalized deprecated constructs and adds sample summaries)
+        log = read_eval_log(input_file)
+        write_eval_log(log, output_file)

     if fs.info(path).type == "file":
         convert_file(path)
inspect_ai/log/_file.py
CHANGED
@@ -524,7 +524,7 @@ def manifest_eval_log_name(info: EvalLogInfo, log_dir: str, sep: str) -> str:

 def log_files_from_ls(
     ls: list[FileInfo],
-    formats: list[Literal["eval", "json"]] | None,
+    formats: list[Literal["eval", "json"]] | None = None,
     descending: bool = True,
 ) -> list[EvalLogInfo]:
     extensions = [f".{format}" for format in (formats or ALL_LOG_FORMATS)]
inspect_ai/log/_log.py
CHANGED
@@ -17,9 +17,11 @@ from pydantic import (
 )
 from rich.console import Console, RenderableType
 from rich.traceback import Traceback
+from shortuuid import uuid

-from inspect_ai._util.constants import CONSOLE_DISPLAY_WIDTH, PKG_NAME
+from inspect_ai._util.constants import CONSOLE_DISPLAY_WIDTH, DESERIALIZING, PKG_NAME
 from inspect_ai._util.error import EvalError, exception_message
+from inspect_ai._util.hash import base57_id_hash
 from inspect_ai._util.logger import warn_once
 from inspect_ai.approval._policy import ApprovalPolicyConfig
 from inspect_ai.dataset._dataset import MT, metadata_as
@@ -677,6 +679,9 @@ class EvalModelConfig(BaseModel):
 class EvalSpec(BaseModel):
     """Eval target and configuration."""

+    eval_id: str = Field(default_factory=str)
+    """Globally unique id for eval."""
+
     run_id: str = Field(default_factory=str)
     """Unique run id"""

@@ -757,6 +762,21 @@ class EvalSpec(BaseModel):
     # allow field model_args
     model_config = ConfigDict(protected_namespaces=())

+    def model_post_init(self, __context: Any) -> None:
+        # check if deserializing
+        is_deserializing = isinstance(__context, dict) and __context.get(
+            DESERIALIZING, False
+        )
+
+        # Generate eval_id if needed
+        if self.eval_id == "":
+            if is_deserializing:
+                # we want the eval_id to be stable across reads of the eval log so we compose it
+                # as a hash that matches the size/apperance of shortuuid-based uuids
+                self.eval_id = base57_id_hash(self.run_id + self.task_id + self.created)
+            else:
+                self.eval_id = uuid()
+
     @model_validator(mode="before")
     @classmethod
     def read_sandbox_spec(
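The practical effect of the `eval_id` logic above is that an id is minted once per eval and, when missing from older logs, is re-derived deterministically on read. A small sketch of that expectation (the log file name is hypothetical):

    # illustrative only: re-reading the same log should yield the same eval_id
    from inspect_ai.log import read_eval_log

    first = read_eval_log("logs/example.eval")
    second = read_eval_log("logs/example.eval")
    assert first.eval.eval_id == second.eval.eval_id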
inspect_ai/log/_samples.py
CHANGED
@@ -5,12 +5,11 @@ from typing import AsyncGenerator, Iterator, Literal

 from shortuuid import uuid

-from inspect_ai._util.constants import SAMPLE_SUBTASK
 from inspect_ai.dataset._dataset import Sample
 from inspect_ai.util._sandbox import SandboxConnection
 from inspect_ai.util._sandbox.context import sandbox_connections

-from ._transcript import
+from ._transcript import ModelEvent, Transcript


 class ActiveSample:
@@ -47,7 +46,6 @@ class ActiveSample:
         self.total_tokens = 0
         self.transcript = transcript
         self.sandboxes = sandboxes
-        self.retry_count = 0
         self._interrupt_action: Literal["score", "error"] | None = None

     @property
@@ -151,27 +149,26 @@ def set_active_sample_total_messages(total_messages: int) -> None:
         active.total_messages = total_messages


+_active_model_event: ContextVar[ModelEvent | None] = ContextVar(
+    "_active_model_event", default=None
+)
+
+
 @contextlib.contextmanager
-def
-
+def track_active_model_event(event: ModelEvent) -> Iterator[None]:
+    token = _active_model_event.set(event)
     try:
         yield
     finally:
-
-
-
-def reset_active_sample_retries() -> None:
-    active = sample_active()
-    if active:
-        active.retry_count = 0
+        _active_model_event.reset(token)


 def report_active_sample_retry() -> None:
-
-    if
-
-
-
+    model_event = _active_model_event.get()
+    if model_event is not None:
+        if model_event.retries is None:
+            model_event.retries = 0
+        model_event.retries = model_event.retries + 1


 _sample_active: ContextVar[ActiveSample | None] = ContextVar(
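The retry-tracking change above follows a common pattern: a `ContextVar` plus a context manager scope the "current" object so that code deeper in the call stack can report against it without threading it through every signature. A generic, self-contained sketch of that pattern (not inspect_ai's actual code):

    # generic sketch of the ContextVar + context manager pattern used above
    import contextlib
    from contextvars import ContextVar
    from typing import Iterator

    _active: ContextVar[dict | None] = ContextVar("_active", default=None)

    @contextlib.contextmanager
    def track_active(obj: dict) -> Iterator[None]:
        token = _active.set(obj)
        try:
            yield
        finally:
            _active.reset(token)

    def report_retry() -> None:
        obj = _active.get()
        if obj is not None:
            obj["retries"] = obj.get("retries", 0) + 1

    with track_active({"retries": 0}):
        report_retry()  # increments the active object's retry count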
inspect_ai/log/_transcript.py
CHANGED
@@ -23,9 +23,10 @@ from pydantic import (
 )
 from shortuuid import uuid

-from inspect_ai._util.constants import
+from inspect_ai._util.constants import DESERIALIZING
 from inspect_ai._util.error import EvalError
-from inspect_ai._util.json import JsonChange
+from inspect_ai._util.json import JsonChange
+from inspect_ai._util.logger import warn_once
 from inspect_ai._util.working import sample_working_time
 from inspect_ai.dataset._dataset import Sample
 from inspect_ai.log._message import LoggingMessage
@@ -34,7 +35,6 @@ from inspect_ai.model._generate_config import GenerateConfig
 from inspect_ai.model._model_call import ModelCall
 from inspect_ai.model._model_output import ModelOutput
 from inspect_ai.scorer._metric import Score
-from inspect_ai.solver._task_state import state_jsonable
 from inspect_ai.tool._tool import ToolResult
 from inspect_ai.tool._tool_call import (
     ToolCall,
@@ -44,6 +44,7 @@ from inspect_ai.tool._tool_call import (
 )
 from inspect_ai.tool._tool_choice import ToolChoice
 from inspect_ai.tool._tool_info import ToolInfo
+from inspect_ai.util._span import current_span_id
 from inspect_ai.util._store import store, store_changes, store_jsonable

 logger = getLogger(__name__)
@@ -57,6 +58,9 @@ class BaseEvent(BaseModel):
     }
     id_: str = Field(default_factory=lambda: str(uuid()), exclude=True)

+    span_id: str | None = Field(default=None)
+    """Span the event occurred within."""
+
     timestamp: datetime = Field(default_factory=datetime.now)
     """Clock time at which event occurred."""

@@ -66,6 +70,17 @@ class BaseEvent(BaseModel):
     pending: bool | None = Field(default=None)
     """Is this event pending?"""

+    def model_post_init(self, __context: Any) -> None:
+        # check if deserializing
+        is_deserializing = isinstance(__context, dict) and __context.get(
+            DESERIALIZING, False
+        )
+
+        # Generate context id fields if not deserializing
+        if not is_deserializing:
+            if self.span_id is None:
+                self.span_id = current_span_id()
+
     @field_serializer("timestamp")
     def serialize_timestamp(self, dt: datetime) -> str:
         return dt.astimezone().isoformat()
@@ -147,6 +162,9 @@ class ModelEvent(BaseEvent):
     output: ModelOutput
     """Output from model."""

+    retries: int | None = Field(default=None)
+    """Retries for the model API request."""
+
     error: str | None = Field(default=None)
     """Error which occurred during model call."""

@@ -203,7 +221,13 @@ class ToolEvent(BaseEvent):
     """Error that occurred during tool call."""

     events: list["Event"] = Field(default_factory=list)
-    """Transcript of events for tool.
+    """Transcript of events for tool.
+
+    Note that events are no longer recorded separately within
+    tool events but rather all events are recorded in the main
+    transcript. This field is deprecated and here for backwards
+    compatibility with transcripts that have sub-events.
+    """

     completed: datetime | None = Field(default=None)
     """Time that tool call completed (see `timestamp` for started)"""
@@ -222,7 +246,6 @@ class ToolEvent(BaseEvent):
         result: ToolResult,
         truncated: tuple[int, int] | None,
         error: ToolCallError | None,
-        events: list["Event"],
         waiting_time: float,
         agent: str | None,
         failed: bool | None,
@@ -230,7 +253,6 @@ class ToolEvent(BaseEvent):
         self.result = result
         self.truncated = truncated
         self.error = error
-        self.events = events
         self.pending = None
         completed = datetime.now()
         self.completed = completed
@@ -402,6 +424,35 @@ class ScoreEvent(BaseEvent):
     """Was this an intermediate scoring?"""


+class SpanBeginEvent(BaseEvent):
+    """Mark the beginning of a transcript span."""
+
+    event: Literal["span_begin"] = Field(default="span_begin")
+    """Event type."""
+
+    id: str
+    """Unique identifier for span."""
+
+    parent_id: str | None = Field(default=None)
+    """Identifier for parent span."""
+
+    type: str | None = Field(default=None)
+    """Optional 'type' field for span."""
+
+    name: str
+    """Span name."""
+
+
+class SpanEndEvent(BaseEvent):
+    """Mark the end of a transcript span."""
+
+    event: Literal["span_end"] = Field(default="span_end")
+    """Event type."""
+
+    id: str
+    """Unique identifier for span."""
+
+
 class StepEvent(BaseEvent):
     """Step within current sample or subtask."""

@@ -437,7 +488,13 @@ class SubtaskEvent(BaseEvent):
     """Subtask function result."""

     events: list["Event"] = Field(default_factory=list)
-    """Transcript of events for subtask.
+    """Transcript of events for subtask.
+
+    Note that events are no longer recorded separately within
+    subtasks but rather all events are recorded in the main
+    transcript. This field is deprecated and here for backwards
+    compatibility with transcripts that have sub-events.
+    """

     completed: datetime | None = Field(default=None)
     """Time that subtask completed (see `timestamp` for started)"""
@@ -467,6 +524,8 @@ Event: TypeAlias = Union[
     | ErrorEvent
     | LoggerEvent
     | InfoEvent
+    | SpanBeginEvent
+    | SpanEndEvent
     | StepEvent
     | SubtaskEvent,
 ]
@@ -480,8 +539,7 @@ class Transcript:

     _event_logger: Callable[[Event], None] | None

-    def __init__(self
-        self.name = name
+    def __init__(self) -> None:
         self._event_logger = None
         self._events: list[Event] = []

@@ -498,19 +556,20 @@ class Transcript:
     def step(self, name: str, type: str | None = None) -> Iterator[None]:
         """Context manager for recording StepEvent.

+        The `step()` context manager is deprecated and will be removed in a future version.
+        Please use the `span()` context manager instead.
+
         Args:
            name (str): Step name.
            type (str | None): Optional step type.
         """
-
-
-
-
-
-
-
-        # end step event
-        self._event(StepEvent(action="end", name=name, type=type))
+        warn_once(
+            logger,
+            "The `transcript().step()` context manager is deprecated and will "
+            + "be removed in a future version. Please replace the call to step() "
+            + "with a call to span().",
+        )
+        yield

     @property
     def events(self) -> Sequence[Event]:
@@ -551,23 +610,6 @@ def track_store_changes() -> Iterator[None]:
         transcript()._event(StoreEvent(changes=changes))


-@contextlib.contextmanager
-def track_state_changes(type: str | None = None) -> Iterator[None]:
-    # we only want to track for step() inside the the sample
-    # (solver level tracking is handled already and there are
-    # no state changes in subtasks)
-    if transcript().name == SAMPLE_SUBTASK and type != "solver":
-        before = state_jsonable()
-        yield
-        after = state_jsonable()
-
-        changes = json_changes(before, after)
-        if changes:
-            transcript()._event(StateEvent(changes=changes))
-    else:
-        yield
-
-
 def init_transcript(transcript: Transcript) -> None:
     _transcript.set(transcript)

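Finally, the `step()` deprecation above points users at the new `span()` context manager (added in `inspect_ai/util/_span.py`). A rough migration sketch, assuming `span()` is exported from `inspect_ai.util` as an async context manager taking a name and optional type (only the deprecation message, not `span()`'s signature, appears in this diff):

    # hypothetical migration from transcript().step() to span()
    from inspect_ai.log import transcript
    from inspect_ai.util import span

    async def prepare_sample() -> None:
        # before (deprecated): emits StepEvent begin/end markers
        with transcript().step("prepare", type="setup"):
            ...

        # after (assumed replacement): emits SpanBeginEvent / SpanEndEvent
        async with span("prepare", type="setup"):
            ...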