inspect-ai 0.3.93__py3-none-any.whl → 0.3.95__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115) hide show
  1. inspect_ai/_display/textual/widgets/samples.py +3 -3
  2. inspect_ai/_display/textual/widgets/transcript.py +3 -29
  3. inspect_ai/_eval/loader.py +1 -1
  4. inspect_ai/_eval/task/run.py +21 -12
  5. inspect_ai/_util/answer.py +26 -0
  6. inspect_ai/_util/constants.py +0 -1
  7. inspect_ai/_util/exception.py +4 -0
  8. inspect_ai/_util/hash.py +39 -0
  9. inspect_ai/_util/local_server.py +51 -21
  10. inspect_ai/_util/path.py +22 -0
  11. inspect_ai/_util/trace.py +1 -1
  12. inspect_ai/_util/working.py +4 -0
  13. inspect_ai/_view/www/dist/assets/index.css +23 -22
  14. inspect_ai/_view/www/dist/assets/index.js +517 -204
  15. inspect_ai/_view/www/log-schema.json +375 -0
  16. inspect_ai/_view/www/package.json +1 -1
  17. inspect_ai/_view/www/src/@types/log.d.ts +90 -12
  18. inspect_ai/_view/www/src/app/log-view/navbar/SecondaryBar.tsx +2 -2
  19. inspect_ai/_view/www/src/app/log-view/tabs/SamplesTab.tsx +1 -4
  20. inspect_ai/_view/www/src/app/samples/SamplesTools.tsx +3 -13
  21. inspect_ai/_view/www/src/app/samples/sample-tools/SelectScorer.tsx +45 -48
  22. inspect_ai/_view/www/src/app/samples/sample-tools/filters.ts +16 -15
  23. inspect_ai/_view/www/src/app/samples/sample-tools/sample-filter/SampleFilter.tsx +47 -75
  24. inspect_ai/_view/www/src/app/samples/sample-tools/sample-filter/completions.ts +9 -9
  25. inspect_ai/_view/www/src/app/samples/transcript/SandboxEventView.module.css +2 -1
  26. inspect_ai/_view/www/src/app/samples/transcript/SpanEventView.tsx +174 -0
  27. inspect_ai/_view/www/src/app/samples/transcript/ToolEventView.tsx +8 -8
  28. inspect_ai/_view/www/src/app/samples/transcript/TranscriptView.tsx +12 -2
  29. inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualListComponent.module.css +1 -1
  30. inspect_ai/_view/www/src/app/samples/transcript/event/EventPanel.tsx +0 -3
  31. inspect_ai/_view/www/src/app/samples/transcript/transform/fixups.ts +87 -25
  32. inspect_ai/_view/www/src/app/samples/transcript/transform/treeify.ts +229 -17
  33. inspect_ai/_view/www/src/app/samples/transcript/transform/utils.ts +11 -0
  34. inspect_ai/_view/www/src/app/samples/transcript/types.ts +5 -1
  35. inspect_ai/_view/www/src/app/types.ts +12 -2
  36. inspect_ai/_view/www/src/components/ExpandablePanel.module.css +1 -1
  37. inspect_ai/_view/www/src/components/ExpandablePanel.tsx +5 -5
  38. inspect_ai/_view/www/src/state/hooks.ts +19 -3
  39. inspect_ai/_view/www/src/state/logSlice.ts +23 -5
  40. inspect_ai/_view/www/yarn.lock +9 -9
  41. inspect_ai/agent/_as_solver.py +3 -1
  42. inspect_ai/agent/_as_tool.py +6 -4
  43. inspect_ai/agent/_bridge/patch.py +1 -3
  44. inspect_ai/agent/_handoff.py +5 -1
  45. inspect_ai/agent/_react.py +4 -3
  46. inspect_ai/agent/_run.py +6 -1
  47. inspect_ai/agent/_types.py +9 -0
  48. inspect_ai/analysis/__init__.py +0 -0
  49. inspect_ai/analysis/beta/__init__.py +57 -0
  50. inspect_ai/analysis/beta/_dataframe/__init__.py +0 -0
  51. inspect_ai/analysis/beta/_dataframe/columns.py +145 -0
  52. inspect_ai/analysis/beta/_dataframe/evals/__init__.py +0 -0
  53. inspect_ai/analysis/beta/_dataframe/evals/columns.py +132 -0
  54. inspect_ai/analysis/beta/_dataframe/evals/extract.py +23 -0
  55. inspect_ai/analysis/beta/_dataframe/evals/table.py +140 -0
  56. inspect_ai/analysis/beta/_dataframe/events/__init__.py +0 -0
  57. inspect_ai/analysis/beta/_dataframe/events/columns.py +37 -0
  58. inspect_ai/analysis/beta/_dataframe/events/table.py +14 -0
  59. inspect_ai/analysis/beta/_dataframe/extract.py +54 -0
  60. inspect_ai/analysis/beta/_dataframe/messages/__init__.py +0 -0
  61. inspect_ai/analysis/beta/_dataframe/messages/columns.py +60 -0
  62. inspect_ai/analysis/beta/_dataframe/messages/extract.py +21 -0
  63. inspect_ai/analysis/beta/_dataframe/messages/table.py +87 -0
  64. inspect_ai/analysis/beta/_dataframe/record.py +377 -0
  65. inspect_ai/analysis/beta/_dataframe/samples/__init__.py +0 -0
  66. inspect_ai/analysis/beta/_dataframe/samples/columns.py +73 -0
  67. inspect_ai/analysis/beta/_dataframe/samples/extract.py +82 -0
  68. inspect_ai/analysis/beta/_dataframe/samples/table.py +329 -0
  69. inspect_ai/analysis/beta/_dataframe/util.py +157 -0
  70. inspect_ai/analysis/beta/_dataframe/validate.py +171 -0
  71. inspect_ai/dataset/_dataset.py +6 -3
  72. inspect_ai/log/__init__.py +10 -0
  73. inspect_ai/log/_convert.py +4 -9
  74. inspect_ai/log/_file.py +1 -1
  75. inspect_ai/log/_log.py +21 -1
  76. inspect_ai/log/_samples.py +14 -17
  77. inspect_ai/log/_transcript.py +77 -35
  78. inspect_ai/log/_tree.py +118 -0
  79. inspect_ai/model/_call_tools.py +44 -35
  80. inspect_ai/model/_model.py +51 -44
  81. inspect_ai/model/_openai_responses.py +17 -18
  82. inspect_ai/model/_providers/anthropic.py +30 -5
  83. inspect_ai/model/_providers/hf.py +27 -1
  84. inspect_ai/model/_providers/providers.py +1 -1
  85. inspect_ai/model/_providers/sglang.py +8 -2
  86. inspect_ai/model/_providers/vllm.py +6 -2
  87. inspect_ai/scorer/_choice.py +1 -2
  88. inspect_ai/solver/_chain.py +1 -1
  89. inspect_ai/solver/_fork.py +1 -1
  90. inspect_ai/solver/_multiple_choice.py +9 -23
  91. inspect_ai/solver/_plan.py +2 -2
  92. inspect_ai/solver/_task_state.py +7 -3
  93. inspect_ai/solver/_transcript.py +6 -7
  94. inspect_ai/tool/_mcp/_context.py +3 -5
  95. inspect_ai/tool/_mcp/_mcp.py +6 -5
  96. inspect_ai/tool/_mcp/server.py +1 -1
  97. inspect_ai/tool/_tools/_execute.py +4 -1
  98. inspect_ai/tool/_tools/_think.py +1 -1
  99. inspect_ai/tool/_tools/_web_search/__init__.py +3 -0
  100. inspect_ai/tool/_tools/{_web_search.py → _web_search/_google.py} +56 -103
  101. inspect_ai/tool/_tools/_web_search/_tavily.py +77 -0
  102. inspect_ai/tool/_tools/_web_search/_web_search.py +85 -0
  103. inspect_ai/util/__init__.py +4 -0
  104. inspect_ai/util/_anyio.py +11 -0
  105. inspect_ai/util/_collect.py +50 -0
  106. inspect_ai/util/_sandbox/events.py +3 -2
  107. inspect_ai/util/_span.py +58 -0
  108. inspect_ai/util/_subtask.py +27 -42
  109. {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/METADATA +8 -1
  110. {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/RECORD +114 -82
  111. {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/WHEEL +1 -1
  112. inspect_ai/_display/core/group.py +0 -79
  113. {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/entry_points.txt +0 -0
  114. {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/licenses/LICENSE +0 -0
  115. {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,171 @@
1
+ from __future__ import annotations
2
+
3
+ from logging import getLogger
4
+ from typing import Any, Iterator, Mapping, Type
5
+
6
+ import jsonref # type: ignore
7
+ from jsonpath_ng import Fields, Index, JSONPath, Slice, Where, WhereNot # type: ignore
8
+ from jsonpath_ng.ext.filter import Filter # type: ignore
9
+ from pydantic import BaseModel
10
+
11
+ logger = getLogger(__name__)
12
+
13
+ Schema = Mapping[str, Any]
14
+
15
+
16
def resolved_schema(model: Type[BaseModel]) -> Schema:
    """Build the JSON schema for a pydantic model with its `$ref`s replaced.

    Args:
        model: Pydantic model class to generate the schema for.

    Returns:
        The schema returned by `jsonref.replace_refs()` for the model's
        JSON schema (called with `proxies=False`).
    """
    # base URI handed to jsonref for resolving references
    base_uri = "file:///memory/inspect_schema.json"
    raw_schema = model.model_json_schema()
    resolved: Schema = jsonref.replace_refs(
        raw_schema, base_uri=base_uri, jsonschema=True, proxies=False
    )
    return resolved
23
+
24
+
25
def jsonpath_in_schema(expr: JSONPath, schema: Schema) -> bool:
    """Check whether a JSONPath expression can match something in `schema`.

    The expression is linearised with `iter_tokens()` and the schema is walked
    one token at a time while tracking every sub-schema that is still
    reachable. The path is accepted when each token leaves at least one
    reachable sub-schema.

    Args:
        expr: Parsed jsonpath-ng expression.
        schema: JSON schema to validate against (e.g. from `resolved_schema()`).

    Returns:
        True if every segment of the path finds at least one sub-schema, or if
        the expression contains constructs reported by `find_unsupported()`
        (those are not validated). False if some segment matches nothing.
    """
    # don't validate unsupported constructs
    if find_unsupported(expr):
        return True

    # Keywords that describe a schema without constraining its instances. A
    # schema made up only of these (including the empty schema) accepts any
    # value, so any child path below it is considered valid.
    annotation_keys = frozenset(
        {
            "title",
            "description",
            "default",
            "examples",
            "deprecated",
            "readOnly",
            "writeOnly",
            "$comment",
        }
    )

    def _types(sch: Schema) -> set[str]:
        """Return the declared `type` of a schema as a set (empty if none)."""
        t = sch.get("type")
        return set(t) if isinstance(t, list) else {t} if t else set()

    def _is_unconstrained(sch: Schema) -> bool:
        """True for `{}` and for schemas that only carry annotations."""
        return all(key in annotation_keys for key in sch)

    def _unknown_key_schemas(node: Schema) -> list[Schema]:
        """Return the schema(s) that govern keys not listed in `properties`.

        - additionalProperties missing -> open ({}) only when no properties
          are declared; otherwise unknown keys are rejected
        - True                         -> open object -> {}
        - Mapping                      -> that mapping ({} or a real subschema)
        - False                        -> closed object -> nothing
        """
        if "additionalProperties" not in node:
            return [] if node.get("properties") else [{}]
        extra = node["additionalProperties"]
        if extra is True:
            return [{}]
        if isinstance(extra, Mapping):  # {} or {...}
            return [extra]
        return []  # closed dict

    def _normalize_items(items: Any, index: int | None = None) -> list[Schema]:
        """Select item schemas from an `items` / `prefixItems` value."""
        if isinstance(items, list):
            if index is None:  # wildcard/slice
                return items
            # accept negative indices that stay within the positional list
            if -len(items) <= index < len(items):
                return [items[index]]
            return []
        if isinstance(items, Mapping):
            return [items]
        return []

    def _array_children(sch: Schema, index: int | None = None) -> list[Schema]:
        """Return the item schemas reachable from an array schema."""
        if "array" not in _types(sch):
            return []
        # an array that does not constrain its items accepts any element
        if "items" not in sch and "prefixItems" not in sch:
            return [{}]
        found: list[Schema] = []
        # `prefixItems` is the JSON Schema 2020-12 tuple form; `items` may be
        # a single schema or (older drafts) a positional list
        for key in ("prefixItems", "items"):
            if key in sch:
                found.extend(_normalize_items(sch[key], index=index))
        return found

    def descend_concrete(sch: Schema, tok: str | int | None) -> list[Schema]:
        """Apply one token to a schema that is not a combinator."""
        # totally open schema – stay alive, accept any child
        if _is_unconstrained(sch):
            return [{}]

        # Wildcard -----------------------------------------------------------
        if tok is None:
            outs: list[Schema] = []
            if "properties" in sch:
                outs.extend(sch["properties"].values())
            if "object" in _types(sch):
                outs.extend(_unknown_key_schemas(sch))
            outs.extend(_array_children(sch))
            return outs

        # Property access ----------------------------------------------------
        if isinstance(tok, str):
            if "properties" in sch and tok in sch["properties"]:
                return [sch["properties"][tok]]
            # PRESENCE of additionalProperties matters, not its truthiness
            if "additionalProperties" in sch or "object" in _types(sch):
                return _unknown_key_schemas(sch)
            return []

        # Array index --------------------------------------------------------
        return _array_children(sch, index=tok)

    def descend(sch: Schema, tok: str | int | None) -> list[Schema]:
        """Apply one token, first branching through anyOf/oneOf/allOf."""
        outs: list[Schema] = []
        for branch in _expand_union(sch):
            outs.extend(descend_concrete(branch, tok))
        return outs

    states = [schema]
    for tok in iter_tokens(expr):
        next_states: list[Schema] = []
        for st in states:
            next_states.extend(descend(st, tok))
        if not next_states:  # nothing matched this segment
            return False
        states = next_states
    return True  # every segment found at least one schema
113
+
114
+
115
def iter_tokens(node: JSONPath) -> Iterator[str | int | None]:
    """Linearise a jsonpath-ng AST into a stream of tokens we care about.

    Yields:
        - `str` for access to a single named field
        - `int` (or whatever `Index.index` holds) for an array index
        - `None` for a wildcard: a slice, the `*` field, or a multi-field
          selection such as `a,b`

    Nodes that are none of the handled kinds contribute no tokens.
    """
    # NOTE(review): every node exposing `.left` (Child, Descendants, etc.) is
    # linearised as left-then-right – confirm this is intended for `..`.
    if hasattr(node, "left"):
        yield from iter_tokens(node.left)
        yield from iter_tokens(node.right)
    elif isinstance(node, Fields):
        fields = node.fields
        if "*" in fields or len(fields) > 1:
            # `*` means "any field" and `a,b` means "either field" at this
            # level; neither is a literal property name nor a nested path, so
            # treat them as a wildcard (over-approximates, never rejects a
            # valid path)
            yield None
        else:
            yield from fields  # e.g. ["foo"]
    elif isinstance(node, Index):
        yield node.index
    elif isinstance(node, Slice):
        yield None  # treat any slice as wildcard
126
+
127
+
128
+ COMBINATORS = ("anyOf", "oneOf", "allOf")
129
+
130
+
131
+ def _expand_union(sch: Schema) -> list[Schema]:
132
+ """Return sch itself or the list of subschemas if it is a combinator."""
133
+ for key in COMBINATORS:
134
+ if key in sch:
135
+ subs: list[Schema] = []
136
+ for sub in sch[key]:
137
+ # a sub-schema might itself be an anyOf/oneOf/allOf
138
+ subs.extend(_expand_union(sub))
139
+ return subs
140
+ return [sch]
141
+
142
+
143
+ UNSUPPORTED: tuple[type[JSONPath], ...] = (
144
+ Filter, # [?foo > 0]
145
+ Where, # .foo[(@.bar < 42)]
146
+ WhereNot,
147
+ Slice, # [1:5] (wildcard “[*]” is Index/None, not Slice)
148
+ )
149
+
150
+
151
def find_unsupported(node: JSONPath) -> list[type[JSONPath]]:
    """Collect the types of nodes in `node` that we do not validate.

    Walks the expression iteratively with an explicit stack and records the
    type of every visited node that is an instance of one of the UNSUPPORTED
    classes. An empty list means nothing unsupported was found.
    """
    found: list[type[JSONPath]] = []
    pending: list[JSONPath] = [node]
    while pending:
        current = pending.pop()
        if isinstance(current, UNSUPPORTED):
            found.append(type(current))

        # Drill into children (jsonpath-ng uses .left / .right / .child attributes)
        for attr in ("left", "right", "child", "expression"):
            candidate = getattr(current, attr, None)
            if isinstance(candidate, JSONPath):
                pending.append(candidate)

        # also follow JSONPath nodes held inside list-valued instance attributes
        instance_attrs = getattr(current, "__dict__", None)
        if instance_attrs is None:
            continue
        for value in instance_attrs.values():
            if isinstance(value, list):
                pending.extend(item for item in value if isinstance(item, JSONPath))
    return found
@@ -16,6 +16,7 @@ from typing import (
16
16
  from pydantic import BaseModel, Field, ValidationError
17
17
  from typing_extensions import override
18
18
 
19
+ from inspect_ai._util.answer import answer_character, answer_index
19
20
  from inspect_ai.model import ChatMessage
20
21
  from inspect_ai.util import SandboxEnvironmentSpec, SandboxEnvironmentType
21
22
  from inspect_ai.util._sandbox.environment import resolve_sandbox_environment
@@ -328,7 +329,9 @@ class MemoryDataset(Dataset):
328
329
  shuffled_choices = [sample.choices[i] for i in positions]
329
330
 
330
331
  # Map of original position / target letter
331
- position_map = {i: chr(65 + new_i) for new_i, i in enumerate(positions)}
332
+ position_map = {
333
+ i: answer_character(new_i) for new_i, i in enumerate(positions)
334
+ }
332
335
 
333
336
  # Update to the shuffled choices and target
334
337
  sample.choices = shuffled_choices
@@ -338,9 +341,9 @@ class MemoryDataset(Dataset):
338
341
  self, target: str | list[str], position_map: dict[int, str]
339
342
  ) -> str | list[str]:
340
343
  if isinstance(target, list):
341
- return [position_map[ord(t) - 65] for t in target]
344
+ return [position_map[answer_index(t)] for t in target]
342
345
  else:
343
- return position_map[ord(target) - 65]
346
+ return position_map[answer_index(target)]
344
347
 
345
348
  @override
346
349
  def sort(
@@ -48,6 +48,8 @@ from ._transcript import (
48
48
  SampleLimitEvent,
49
49
  SandboxEvent,
50
50
  ScoreEvent,
51
+ SpanBeginEvent,
52
+ SpanEndEvent,
51
53
  StateEvent,
52
54
  StepEvent,
53
55
  StoreEvent,
@@ -56,6 +58,7 @@ from ._transcript import (
56
58
  Transcript,
57
59
  transcript,
58
60
  )
61
+ from ._tree import EventNode, EventTree, SpanNode, event_sequence, event_tree
59
62
 
60
63
  __all__ = [
61
64
  "EvalConfig",
@@ -92,6 +95,8 @@ __all__ = [
92
95
  "SampleLimitEvent",
93
96
  "SandboxEvent",
94
97
  "ScoreEvent",
98
+ "SpanBeginEvent",
99
+ "SpanEndEvent",
95
100
  "StateEvent",
96
101
  "StepEvent",
97
102
  "StoreEvent",
@@ -111,4 +116,9 @@ __all__ = [
111
116
  "write_log_dir_manifest",
112
117
  "retryable_eval_logs",
113
118
  "bundle_log_dir",
119
+ "event_tree",
120
+ "event_sequence",
121
+ "EventTree",
122
+ "EventNode",
123
+ "SpanNode",
114
124
  ]
@@ -2,7 +2,7 @@ import os
2
2
  from typing import Literal
3
3
 
4
4
  from inspect_ai._util.error import PrerequisiteError
5
- from inspect_ai._util.file import copy_file, exists, filesystem
5
+ from inspect_ai._util.file import exists, filesystem
6
6
  from inspect_ai.log._file import (
7
7
  log_files_from_ls,
8
8
  read_eval_log,
@@ -66,14 +66,9 @@ def convert_eval_logs(
66
66
  "Output file {output_file} already exists (use --overwrite to overwrite existing files)"
67
67
  )
68
68
 
69
- # if the input and output files have the same format just copy
70
- if input_file.endswith(f".{to}"):
71
- copy_file(input_file, output_file)
72
-
73
- # otherwise do a full read/write
74
- else:
75
- log = read_eval_log(input_file)
76
- write_eval_log(log, output_file)
69
+ # do a full read/write (normalizes deprecated constructs and adds sample summaries)
70
+ log = read_eval_log(input_file)
71
+ write_eval_log(log, output_file)
77
72
 
78
73
  if fs.info(path).type == "file":
79
74
  convert_file(path)
inspect_ai/log/_file.py CHANGED
@@ -524,7 +524,7 @@ def manifest_eval_log_name(info: EvalLogInfo, log_dir: str, sep: str) -> str:
524
524
 
525
525
  def log_files_from_ls(
526
526
  ls: list[FileInfo],
527
- formats: list[Literal["eval", "json"]] | None,
527
+ formats: list[Literal["eval", "json"]] | None = None,
528
528
  descending: bool = True,
529
529
  ) -> list[EvalLogInfo]:
530
530
  extensions = [f".{format}" for format in (formats or ALL_LOG_FORMATS)]
inspect_ai/log/_log.py CHANGED
@@ -17,9 +17,11 @@ from pydantic import (
17
17
  )
18
18
  from rich.console import Console, RenderableType
19
19
  from rich.traceback import Traceback
20
+ from shortuuid import uuid
20
21
 
21
- from inspect_ai._util.constants import CONSOLE_DISPLAY_WIDTH, PKG_NAME
22
+ from inspect_ai._util.constants import CONSOLE_DISPLAY_WIDTH, DESERIALIZING, PKG_NAME
22
23
  from inspect_ai._util.error import EvalError, exception_message
24
+ from inspect_ai._util.hash import base57_id_hash
23
25
  from inspect_ai._util.logger import warn_once
24
26
  from inspect_ai.approval._policy import ApprovalPolicyConfig
25
27
  from inspect_ai.dataset._dataset import MT, metadata_as
@@ -677,6 +679,9 @@ class EvalModelConfig(BaseModel):
677
679
  class EvalSpec(BaseModel):
678
680
  """Eval target and configuration."""
679
681
 
682
+ eval_id: str = Field(default_factory=str)
683
+ """Globally unique id for eval."""
684
+
680
685
  run_id: str = Field(default_factory=str)
681
686
  """Unique run id"""
682
687
 
@@ -757,6 +762,21 @@ class EvalSpec(BaseModel):
757
762
  # allow field model_args
758
763
  model_config = ConfigDict(protected_namespaces=())
759
764
 
765
+ def model_post_init(self, __context: Any) -> None:
766
+ # check if deserializing
767
+ is_deserializing = isinstance(__context, dict) and __context.get(
768
+ DESERIALIZING, False
769
+ )
770
+
771
+ # Generate eval_id if needed
772
+ if self.eval_id == "":
773
+ if is_deserializing:
774
+ # we want the eval_id to be stable across reads of the eval log so we compose it
775
+ # as a hash that matches the size/appearance of shortuuid-based uuids
776
+ self.eval_id = base57_id_hash(self.run_id + self.task_id + self.created)
777
+ else:
778
+ self.eval_id = uuid()
779
+
760
780
  @model_validator(mode="before")
761
781
  @classmethod
762
782
  def read_sandbox_spec(
@@ -5,12 +5,11 @@ from typing import AsyncGenerator, Iterator, Literal
5
5
 
6
6
  from shortuuid import uuid
7
7
 
8
- from inspect_ai._util.constants import SAMPLE_SUBTASK
9
8
  from inspect_ai.dataset._dataset import Sample
10
9
  from inspect_ai.util._sandbox import SandboxConnection
11
10
  from inspect_ai.util._sandbox.context import sandbox_connections
12
11
 
13
- from ._transcript import Transcript, transcript
12
+ from ._transcript import ModelEvent, Transcript
14
13
 
15
14
 
16
15
  class ActiveSample:
@@ -47,7 +46,6 @@ class ActiveSample:
47
46
  self.total_tokens = 0
48
47
  self.transcript = transcript
49
48
  self.sandboxes = sandboxes
50
- self.retry_count = 0
51
49
  self._interrupt_action: Literal["score", "error"] | None = None
52
50
 
53
51
  @property
@@ -151,27 +149,26 @@ def set_active_sample_total_messages(total_messages: int) -> None:
151
149
  active.total_messages = total_messages
152
150
 
153
151
 
152
+ _active_model_event: ContextVar[ModelEvent | None] = ContextVar(
153
+ "_active_model_event", default=None
154
+ )
155
+
156
+
154
157
  @contextlib.contextmanager
155
- def track_active_sample_retries() -> Iterator[None]:
156
- reset_active_sample_retries()
158
+ def track_active_model_event(event: ModelEvent) -> Iterator[None]:
159
+ token = _active_model_event.set(event)
157
160
  try:
158
161
  yield
159
162
  finally:
160
- reset_active_sample_retries()
161
-
162
-
163
- def reset_active_sample_retries() -> None:
164
- active = sample_active()
165
- if active:
166
- active.retry_count = 0
163
+ _active_model_event.reset(token)
167
164
 
168
165
 
169
166
  def report_active_sample_retry() -> None:
170
- active = sample_active()
171
- if active:
172
- # only do this for the top level subtask
173
- if transcript().name == SAMPLE_SUBTASK:
174
- active.retry_count = active.retry_count + 1
167
+ model_event = _active_model_event.get()
168
+ if model_event is not None:
169
+ if model_event.retries is None:
170
+ model_event.retries = 0
171
+ model_event.retries = model_event.retries + 1
175
172
 
176
173
 
177
174
  _sample_active: ContextVar[ActiveSample | None] = ContextVar(
@@ -23,9 +23,10 @@ from pydantic import (
23
23
  )
24
24
  from shortuuid import uuid
25
25
 
26
- from inspect_ai._util.constants import SAMPLE_SUBTASK
26
+ from inspect_ai._util.constants import DESERIALIZING
27
27
  from inspect_ai._util.error import EvalError
28
- from inspect_ai._util.json import JsonChange, json_changes
28
+ from inspect_ai._util.json import JsonChange
29
+ from inspect_ai._util.logger import warn_once
29
30
  from inspect_ai._util.working import sample_working_time
30
31
  from inspect_ai.dataset._dataset import Sample
31
32
  from inspect_ai.log._message import LoggingMessage
@@ -34,7 +35,6 @@ from inspect_ai.model._generate_config import GenerateConfig
34
35
  from inspect_ai.model._model_call import ModelCall
35
36
  from inspect_ai.model._model_output import ModelOutput
36
37
  from inspect_ai.scorer._metric import Score
37
- from inspect_ai.solver._task_state import state_jsonable
38
38
  from inspect_ai.tool._tool import ToolResult
39
39
  from inspect_ai.tool._tool_call import (
40
40
  ToolCall,
@@ -44,6 +44,7 @@ from inspect_ai.tool._tool_call import (
44
44
  )
45
45
  from inspect_ai.tool._tool_choice import ToolChoice
46
46
  from inspect_ai.tool._tool_info import ToolInfo
47
+ from inspect_ai.util._span import current_span_id
47
48
  from inspect_ai.util._store import store, store_changes, store_jsonable
48
49
 
49
50
  logger = getLogger(__name__)
@@ -57,6 +58,9 @@ class BaseEvent(BaseModel):
57
58
  }
58
59
  id_: str = Field(default_factory=lambda: str(uuid()), exclude=True)
59
60
 
61
+ span_id: str | None = Field(default=None)
62
+ """Span the event occurred within."""
63
+
60
64
  timestamp: datetime = Field(default_factory=datetime.now)
61
65
  """Clock time at which event occurred."""
62
66
 
@@ -66,6 +70,17 @@ class BaseEvent(BaseModel):
66
70
  pending: bool | None = Field(default=None)
67
71
  """Is this event pending?"""
68
72
 
73
+ def model_post_init(self, __context: Any) -> None:
74
+ # check if deserializing
75
+ is_deserializing = isinstance(__context, dict) and __context.get(
76
+ DESERIALIZING, False
77
+ )
78
+
79
+ # Generate context id fields if not deserializing
80
+ if not is_deserializing:
81
+ if self.span_id is None:
82
+ self.span_id = current_span_id()
83
+
69
84
  @field_serializer("timestamp")
70
85
  def serialize_timestamp(self, dt: datetime) -> str:
71
86
  return dt.astimezone().isoformat()
@@ -147,6 +162,9 @@ class ModelEvent(BaseEvent):
147
162
  output: ModelOutput
148
163
  """Output from model."""
149
164
 
165
+ retries: int | None = Field(default=None)
166
+ """Retries for the model API request."""
167
+
150
168
  error: str | None = Field(default=None)
151
169
  """Error which occurred during model call."""
152
170
 
@@ -203,7 +221,13 @@ class ToolEvent(BaseEvent):
203
221
  """Error that occurred during tool call."""
204
222
 
205
223
  events: list["Event"] = Field(default_factory=list)
206
- """Transcript of events for tool."""
224
+ """Transcript of events for tool.
225
+
226
+ Note that events are no longer recorded separately within
227
+ tool events but rather all events are recorded in the main
228
+ transcript. This field is deprecated and here for backwards
229
+ compatibility with transcripts that have sub-events.
230
+ """
207
231
 
208
232
  completed: datetime | None = Field(default=None)
209
233
  """Time that tool call completed (see `timestamp` for started)"""
@@ -222,7 +246,6 @@ class ToolEvent(BaseEvent):
222
246
  result: ToolResult,
223
247
  truncated: tuple[int, int] | None,
224
248
  error: ToolCallError | None,
225
- events: list["Event"],
226
249
  waiting_time: float,
227
250
  agent: str | None,
228
251
  failed: bool | None,
@@ -230,7 +253,6 @@ class ToolEvent(BaseEvent):
230
253
  self.result = result
231
254
  self.truncated = truncated
232
255
  self.error = error
233
- self.events = events
234
256
  self.pending = None
235
257
  completed = datetime.now()
236
258
  self.completed = completed
@@ -402,6 +424,35 @@ class ScoreEvent(BaseEvent):
402
424
  """Was this an intermediate scoring?"""
403
425
 
404
426
 
427
+ class SpanBeginEvent(BaseEvent):
428
+ """Mark the beginning of a transcript span."""
429
+
430
+ event: Literal["span_begin"] = Field(default="span_begin")
431
+ """Event type."""
432
+
433
+ id: str
434
+ """Unique identifier for span."""
435
+
436
+ parent_id: str | None = Field(default=None)
437
+ """Identifier for parent span."""
438
+
439
+ type: str | None = Field(default=None)
440
+ """Optional 'type' field for span."""
441
+
442
+ name: str
443
+ """Span name."""
444
+
445
+
446
+ class SpanEndEvent(BaseEvent):
447
+ """Mark the end of a transcript span."""
448
+
449
+ event: Literal["span_end"] = Field(default="span_end")
450
+ """Event type."""
451
+
452
+ id: str
453
+ """Unique identifier for span."""
454
+
455
+
405
456
  class StepEvent(BaseEvent):
406
457
  """Step within current sample or subtask."""
407
458
 
@@ -437,7 +488,13 @@ class SubtaskEvent(BaseEvent):
437
488
  """Subtask function result."""
438
489
 
439
490
  events: list["Event"] = Field(default_factory=list)
440
- """Transcript of events for subtask."""
491
+ """Transcript of events for subtask.
492
+
493
+ Note that events are no longer recorded separately within
494
+ subtasks but rather all events are recorded in the main
495
+ transcript. This field is deprecated and here for backwards
496
+ compatibility with transcripts that have sub-events.
497
+ """
441
498
 
442
499
  completed: datetime | None = Field(default=None)
443
500
  """Time that subtask completed (see `timestamp` for started)"""
@@ -467,6 +524,8 @@ Event: TypeAlias = Union[
467
524
  | ErrorEvent
468
525
  | LoggerEvent
469
526
  | InfoEvent
527
+ | SpanBeginEvent
528
+ | SpanEndEvent
470
529
  | StepEvent
471
530
  | SubtaskEvent,
472
531
  ]
@@ -480,8 +539,7 @@ class Transcript:
480
539
 
481
540
  _event_logger: Callable[[Event], None] | None
482
541
 
483
- def __init__(self, name: str = "") -> None:
484
- self.name = name
542
+ def __init__(self) -> None:
485
543
  self._event_logger = None
486
544
  self._events: list[Event] = []
487
545
 
@@ -498,19 +556,20 @@ class Transcript:
498
556
  def step(self, name: str, type: str | None = None) -> Iterator[None]:
499
557
  """Context manager for recording StepEvent.
500
558
 
559
+ The `step()` context manager is deprecated and will be removed in a future version.
560
+ Please use the `span()` context manager instead.
561
+
501
562
  Args:
502
563
  name (str): Step name.
503
564
  type (str | None): Optional step type.
504
565
  """
505
- # step event
506
- self._event(StepEvent(action="begin", name=name, type=type))
507
-
508
- # run the step (tracking state/store changes)
509
- with track_state_changes(type), track_store_changes():
510
- yield
511
-
512
- # end step event
513
- self._event(StepEvent(action="end", name=name, type=type))
566
+ warn_once(
567
+ logger,
568
+ "The `transcript().step()` context manager is deprecated and will "
569
+ + "be removed in a future version. Please replace the call to step() "
570
+ + "with a call to span().",
571
+ )
572
+ yield
514
573
 
515
574
  @property
516
575
  def events(self) -> Sequence[Event]:
@@ -551,23 +610,6 @@ def track_store_changes() -> Iterator[None]:
551
610
  transcript()._event(StoreEvent(changes=changes))
552
611
 
553
612
 
554
- @contextlib.contextmanager
555
- def track_state_changes(type: str | None = None) -> Iterator[None]:
556
- # we only want to track for step() inside the the sample
557
- # (solver level tracking is handled already and there are
558
- # no state changes in subtasks)
559
- if transcript().name == SAMPLE_SUBTASK and type != "solver":
560
- before = state_jsonable()
561
- yield
562
- after = state_jsonable()
563
-
564
- changes = json_changes(before, after)
565
- if changes:
566
- transcript()._event(StateEvent(changes=changes))
567
- else:
568
- yield
569
-
570
-
571
613
  def init_transcript(transcript: Transcript) -> None:
572
614
  _transcript.set(transcript)
573
615