inspect-ai 0.3.92__py3-none-any.whl → 0.3.94__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (149)
  1. inspect_ai/_cli/eval.py +27 -0
  2. inspect_ai/_display/textual/widgets/samples.py +3 -3
  3. inspect_ai/_display/textual/widgets/transcript.py +3 -29
  4. inspect_ai/_eval/eval.py +19 -2
  5. inspect_ai/_eval/evalset.py +4 -1
  6. inspect_ai/_eval/run.py +41 -0
  7. inspect_ai/_eval/task/generate.py +38 -44
  8. inspect_ai/_eval/task/log.py +26 -28
  9. inspect_ai/_eval/task/run.py +23 -27
  10. inspect_ai/_util/answer.py +26 -0
  11. inspect_ai/_util/constants.py +0 -1
  12. inspect_ai/_util/local_server.py +398 -0
  13. inspect_ai/_util/working.py +10 -4
  14. inspect_ai/_view/www/dist/assets/index.css +173 -159
  15. inspect_ai/_view/www/dist/assets/index.js +1417 -1142
  16. inspect_ai/_view/www/log-schema.json +379 -3
  17. inspect_ai/_view/www/package.json +1 -1
  18. inspect_ai/_view/www/src/@types/log.d.ts +93 -14
  19. inspect_ai/_view/www/src/app/content/MetaDataGrid.tsx +2 -2
  20. inspect_ai/_view/www/src/app/content/MetaDataView.module.css +1 -1
  21. inspect_ai/_view/www/src/app/content/MetadataGrid.module.css +1 -1
  22. inspect_ai/_view/www/src/app/content/RenderedContent.tsx +1 -1
  23. inspect_ai/_view/www/src/app/log-view/LogView.tsx +11 -0
  24. inspect_ai/_view/www/src/app/log-view/tabs/InfoTab.tsx +2 -9
  25. inspect_ai/_view/www/src/app/log-view/tabs/ModelsTab.tsx +51 -0
  26. inspect_ai/_view/www/src/app/log-view/tabs/TaskTab.module.css +6 -0
  27. inspect_ai/_view/www/src/app/log-view/tabs/TaskTab.tsx +143 -0
  28. inspect_ai/_view/www/src/app/plan/ModelCard.tsx +1 -2
  29. inspect_ai/_view/www/src/app/plan/PlanCard.tsx +29 -7
  30. inspect_ai/_view/www/src/app/plan/PlanDetailView.module.css +1 -1
  31. inspect_ai/_view/www/src/app/plan/PlanDetailView.tsx +1 -198
  32. inspect_ai/_view/www/src/app/samples/descriptor/score/NumericScoreDescriptor.tsx +2 -1
  33. inspect_ai/_view/www/src/app/samples/transcript/SandboxEventView.module.css +2 -1
  34. inspect_ai/_view/www/src/app/samples/transcript/SpanEventView.tsx +174 -0
  35. inspect_ai/_view/www/src/app/samples/transcript/ToolEventView.tsx +8 -8
  36. inspect_ai/_view/www/src/app/samples/transcript/TranscriptView.tsx +12 -2
  37. inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualListComponent.module.css +1 -1
  38. inspect_ai/_view/www/src/app/samples/transcript/event/EventPanel.tsx +0 -3
  39. inspect_ai/_view/www/src/app/samples/transcript/transform/fixups.ts +87 -25
  40. inspect_ai/_view/www/src/app/samples/transcript/transform/treeify.ts +229 -17
  41. inspect_ai/_view/www/src/app/samples/transcript/transform/utils.ts +11 -0
  42. inspect_ai/_view/www/src/app/samples/transcript/types.ts +5 -1
  43. inspect_ai/_view/www/src/app/usage/ModelUsagePanel.tsx +3 -2
  44. inspect_ai/_view/www/src/app/usage/TokenTable.module.css +4 -1
  45. inspect_ai/_view/www/src/app/usage/TokenTable.tsx +2 -2
  46. inspect_ai/_view/www/src/app/usage/UsageCard.module.css +8 -3
  47. inspect_ai/_view/www/src/app/usage/UsageCard.tsx +1 -35
  48. inspect_ai/_view/www/src/components/Card.css +0 -1
  49. inspect_ai/_view/www/src/constants.ts +2 -0
  50. inspect_ai/_view/www/src/utils/numeric.ts +17 -0
  51. inspect_ai/agent/_agent.py +3 -3
  52. inspect_ai/agent/_as_solver.py +22 -12
  53. inspect_ai/agent/_as_tool.py +20 -6
  54. inspect_ai/agent/_handoff.py +12 -1
  55. inspect_ai/agent/_react.py +4 -3
  56. inspect_ai/agent/_run.py +16 -3
  57. inspect_ai/agent/_types.py +9 -0
  58. inspect_ai/dataset/_dataset.py +6 -3
  59. inspect_ai/log/__init__.py +14 -0
  60. inspect_ai/log/_convert.py +4 -9
  61. inspect_ai/log/_file.py +56 -0
  62. inspect_ai/log/_log.py +99 -0
  63. inspect_ai/log/_recorders/__init__.py +2 -0
  64. inspect_ai/log/_recorders/buffer/database.py +12 -11
  65. inspect_ai/log/_recorders/buffer/filestore.py +2 -2
  66. inspect_ai/log/_recorders/buffer/types.py +2 -2
  67. inspect_ai/log/_recorders/eval.py +20 -65
  68. inspect_ai/log/_recorders/file.py +28 -6
  69. inspect_ai/log/_recorders/recorder.py +7 -0
  70. inspect_ai/log/_recorders/types.py +1 -23
  71. inspect_ai/log/_samples.py +14 -25
  72. inspect_ai/log/_transcript.py +84 -36
  73. inspect_ai/log/_tree.py +118 -0
  74. inspect_ai/log/_util.py +52 -0
  75. inspect_ai/model/__init__.py +5 -1
  76. inspect_ai/model/_call_tools.py +72 -44
  77. inspect_ai/model/_generate_config.py +14 -8
  78. inspect_ai/model/_model.py +66 -88
  79. inspect_ai/model/_model_output.py +25 -0
  80. inspect_ai/model/_openai.py +2 -0
  81. inspect_ai/model/_providers/anthropic.py +13 -23
  82. inspect_ai/model/_providers/hf.py +27 -1
  83. inspect_ai/model/_providers/openai_o1.py +8 -2
  84. inspect_ai/model/_providers/providers.py +18 -4
  85. inspect_ai/model/_providers/sglang.py +247 -0
  86. inspect_ai/model/_providers/vllm.py +211 -400
  87. inspect_ai/scorer/_choice.py +1 -2
  88. inspect_ai/solver/__init__.py +7 -2
  89. inspect_ai/solver/_basic_agent.py +3 -10
  90. inspect_ai/solver/_chain.py +1 -1
  91. inspect_ai/solver/_fork.py +1 -1
  92. inspect_ai/solver/_multiple_choice.py +5 -22
  93. inspect_ai/solver/_plan.py +2 -2
  94. inspect_ai/solver/_task_state.py +26 -88
  95. inspect_ai/solver/_transcript.py +6 -7
  96. inspect_ai/tool/_json_rpc_helpers.py +45 -17
  97. inspect_ai/tool/_mcp/_mcp.py +8 -5
  98. inspect_ai/tool/_mcp/_sandbox.py +8 -2
  99. inspect_ai/tool/_mcp/server.py +3 -1
  100. inspect_ai/tool/_tool_call.py +4 -1
  101. inspect_ai/tool/_tool_support_helpers.py +51 -12
  102. inspect_ai/tool/_tools/_bash_session.py +190 -68
  103. inspect_ai/tool/_tools/_computer/_computer.py +25 -1
  104. inspect_ai/tool/_tools/_execute.py +4 -1
  105. inspect_ai/tool/_tools/_text_editor.py +4 -3
  106. inspect_ai/tool/_tools/_web_browser/_web_browser.py +10 -3
  107. inspect_ai/util/__init__.py +16 -0
  108. inspect_ai/util/_anyio.py +11 -0
  109. inspect_ai/util/_collect.py +50 -0
  110. inspect_ai/util/_limit.py +393 -0
  111. inspect_ai/util/_limited_conversation.py +57 -0
  112. inspect_ai/util/_span.py +58 -0
  113. inspect_ai/util/_subtask.py +27 -42
  114. {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.94.dist-info}/METADATA +1 -1
  115. {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.94.dist-info}/RECORD +120 -134
  116. {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.94.dist-info}/WHEEL +1 -1
  117. inspect_ai/_display/core/group.py +0 -79
  118. inspect_ai/solver/_limit.py +0 -39
  119. inspect_ai/tool/_tools/_computer/_resources/Dockerfile +0 -102
  120. inspect_ai/tool/_tools/_computer/_resources/README.md +0 -30
  121. inspect_ai/tool/_tools/_computer/_resources/entrypoint/entrypoint.sh +0 -18
  122. inspect_ai/tool/_tools/_computer/_resources/entrypoint/novnc_startup.sh +0 -20
  123. inspect_ai/tool/_tools/_computer/_resources/entrypoint/x11vnc_startup.sh +0 -48
  124. inspect_ai/tool/_tools/_computer/_resources/entrypoint/xfce_startup.sh +0 -13
  125. inspect_ai/tool/_tools/_computer/_resources/entrypoint/xvfb_startup.sh +0 -48
  126. inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/Code/User/globalStorage/state.vscdb +0 -0
  127. inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/Code/User/settings.json +0 -9
  128. inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/xfce4/xfconf/xfce-perchannel-xml/xfce4-panel.xml +0 -61
  129. inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/xfce4/xfconf/xfce-perchannel-xml/xfce4-screensaver.xml +0 -10
  130. inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/xfce4/xfconf/xfce-perchannel-xml/xfwm4.xml +0 -91
  131. inspect_ai/tool/_tools/_computer/_resources/image_home_dir/Desktop/Firefox Web Browser.desktop +0 -10
  132. inspect_ai/tool/_tools/_computer/_resources/image_home_dir/Desktop/Terminal.desktop +0 -10
  133. inspect_ai/tool/_tools/_computer/_resources/image_home_dir/Desktop/Visual Studio Code.desktop +0 -10
  134. inspect_ai/tool/_tools/_computer/_resources/tool/.pylintrc +0 -8
  135. inspect_ai/tool/_tools/_computer/_resources/tool/.vscode/settings.json +0 -12
  136. inspect_ai/tool/_tools/_computer/_resources/tool/_args.py +0 -78
  137. inspect_ai/tool/_tools/_computer/_resources/tool/_constants.py +0 -22
  138. inspect_ai/tool/_tools/_computer/_resources/tool/_logger.py +0 -22
  139. inspect_ai/tool/_tools/_computer/_resources/tool/_run.py +0 -42
  140. inspect_ai/tool/_tools/_computer/_resources/tool/_tool_result.py +0 -33
  141. inspect_ai/tool/_tools/_computer/_resources/tool/_x11_client.py +0 -341
  142. inspect_ai/tool/_tools/_computer/_resources/tool/computer_tool.py +0 -141
  143. inspect_ai/tool/_tools/_computer/_resources/tool/pyproject.toml +0 -65
  144. inspect_ai/tool/_tools/_computer/_resources/tool/requirements.txt +0 -0
  145. inspect_ai/tool/_tools/_computer/test_args.py +0 -151
  146. /inspect_ai/{tool/_tools/_computer/_resources/tool/__init__.py → _view/www/src/app/log-view/tabs/ModelsTab.module.css} +0 -0
  147. {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.94.dist-info}/entry_points.txt +0 -0
  148. {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.94.dist-info}/licenses/LICENSE +0 -0
  149. {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.94.dist-info}/top_level.txt +0 -0
@@ -14,7 +14,7 @@ from inspect_ai._util.file import FileSystem, basename, dirname, file, filesyste
14
14
  from inspect_ai._util.json import to_json_safe, to_json_str_safe
15
15
  from inspect_ai.log._file import read_eval_log
16
16
 
17
- from ..types import SampleSummary
17
+ from ..._log import EvalSampleSummary
18
18
  from .types import SampleBuffer, SampleData, Samples
19
19
 
20
20
  logger = getLogger(__name__)
@@ -33,7 +33,7 @@ class SegmentFile(BaseModel):
33
33
 
34
34
 
35
35
  class SampleManifest(BaseModel):
36
- summary: SampleSummary
36
+ summary: EvalSampleSummary
37
37
  segments: list[int] = Field(default_factory=list)
38
38
 
39
39
 
@@ -5,13 +5,13 @@ from pydantic import BaseModel, JsonValue
5
5
 
6
6
  from inspect_ai._display.core.display import TaskDisplayMetric
7
7
 
8
- from ..types import SampleSummary
8
+ from ..._log import EvalSampleSummary
9
9
 
10
10
  JsonData: TypeAlias = dict[str, JsonValue]
11
11
 
12
12
 
13
13
  class Samples(BaseModel):
14
- samples: list[SampleSummary]
14
+ samples: list[EvalSampleSummary]
15
15
  metrics: list[TaskDisplayMetric]
16
16
  refresh: int
17
17
  etag: str
@@ -11,18 +11,10 @@ from pydantic_core import to_json
11
11
  from typing_extensions import override
12
12
 
13
13
  from inspect_ai._util.constants import DESERIALIZING_CONTEXT, LOG_SCHEMA_VERSION
14
- from inspect_ai._util.content import (
15
- ContentAudio,
16
- ContentImage,
17
- ContentReasoning,
18
- ContentText,
19
- ContentVideo,
20
- )
21
14
  from inspect_ai._util.error import EvalError
22
15
  from inspect_ai._util.file import FileSystem, dirname, file, filesystem
23
16
  from inspect_ai._util.json import jsonable_python
24
17
  from inspect_ai._util.trace import trace_action
25
- from inspect_ai.model._chat_message import ChatMessage
26
18
 
27
19
  from .._log import (
28
20
  EvalLog,
@@ -30,12 +22,12 @@ from .._log import (
30
22
  EvalResults,
31
23
  EvalSample,
32
24
  EvalSampleReductions,
25
+ EvalSampleSummary,
33
26
  EvalSpec,
34
27
  EvalStats,
35
28
  sort_samples,
36
29
  )
37
30
  from .file import FileRecorder
38
- from .types import SampleSummary
39
31
 
40
32
  logger = getLogger(__name__)
41
33
 
@@ -222,6 +214,15 @@ class EvalRecorder(FileRecorder):
222
214
  f"Sample id {id} for epoch {epoch} not found in log {location}"
223
215
  )
224
216
 
217
+ @classmethod
218
+ @override
219
+ async def read_log_sample_summaries(cls, location: str) -> list[EvalSampleSummary]:
220
+ with file(location, "rb") as z:
221
+ with ZipFile(z, mode="r") as zip:
222
+ summary_counter = _read_summary_counter(zip)
223
+ summaries = _read_all_summaries(zip, summary_counter)
224
+ return summaries
225
+
225
226
  @classmethod
226
227
  @override
227
228
  async def write_log(cls, location: str, log: EvalLog) -> None:
@@ -236,36 +237,6 @@ class EvalRecorder(FileRecorder):
236
237
  )
237
238
 
238
239
 
239
- def text_inputs(inputs: str | list[ChatMessage]) -> str | list[ChatMessage]:
240
- # Clean the input of any images
241
- if isinstance(inputs, list):
242
- input: list[ChatMessage] = []
243
- for message in inputs:
244
- if not isinstance(message.content, str):
245
- filtered_content: list[
246
- ContentText
247
- | ContentReasoning
248
- | ContentImage
249
- | ContentAudio
250
- | ContentVideo
251
- ] = []
252
- for content in message.content:
253
- if content.type == "text":
254
- filtered_content.append(content)
255
- else:
256
- filtered_content.append(
257
- ContentText(text=f"({content.type.capitalize()})")
258
- )
259
- message.content = filtered_content
260
- input.append(message)
261
- else:
262
- input.append(message)
263
-
264
- return input
265
- else:
266
- return inputs
267
-
268
-
269
240
  class ZipLogFile:
270
241
  _zip: ZipFile | None
271
242
  _temp_file: BinaryIO
@@ -273,19 +244,20 @@ class ZipLogFile:
273
244
 
274
245
  def __init__(self, file: str) -> None:
275
246
  self._file = file
247
+ self._zip = None
276
248
  self._fs = filesystem(file)
277
249
  self._lock = anyio.Lock()
278
250
  self._temp_file = tempfile.TemporaryFile()
279
251
  self._samples: list[EvalSample] = []
280
252
  self._summary_counter = 0
281
- self._summaries: list[SampleSummary] = []
253
+ self._summaries: list[EvalSampleSummary] = []
282
254
  self._log_start: LogStart | None = None
283
255
 
284
256
  async def init(
285
257
  self,
286
258
  log_start: LogStart | None,
287
259
  summary_counter: int,
288
- summaries: list[SampleSummary],
260
+ summaries: list[EvalSampleSummary],
289
261
  ) -> None:
290
262
  async with self._lock:
291
263
  self._open()
@@ -309,31 +281,14 @@ class ZipLogFile:
309
281
  async def write_buffered_samples(self) -> None:
310
282
  async with self._lock:
311
283
  # Write the buffered samples
312
- summaries: list[SampleSummary] = []
284
+ summaries: list[EvalSampleSummary] = []
313
285
  for sample in self._samples:
314
286
  # Write the sample
315
287
  self._zip_writestr(_sample_filename(sample.id, sample.epoch), sample)
316
288
 
317
289
  # Capture the summary
318
- summaries.append(
319
- SampleSummary(
320
- id=sample.id,
321
- epoch=sample.epoch,
322
- input=text_inputs(sample.input),
323
- target=sample.target,
324
- completed=True,
325
- scores=sample.scores,
326
- error=sample.error.message
327
- if sample.error is not None
328
- else None,
329
- limit=f"{sample.limit.type}"
330
- if sample.limit is not None
331
- else None,
332
- retries=len(sample.error_retries)
333
- if sample.error_retries is not None
334
- else None,
335
- )
336
- )
290
+ summaries.append(sample.summary())
291
+
337
292
  self._samples.clear()
338
293
 
339
294
  # write intermediary summaries and add to master list
@@ -451,12 +406,12 @@ def _read_summary_counter(zip: ZipFile) -> int:
451
406
  return current_count
452
407
 
453
408
 
454
- def _read_all_summaries(zip: ZipFile, count: int) -> list[SampleSummary]:
409
+ def _read_all_summaries(zip: ZipFile, count: int) -> list[EvalSampleSummary]:
455
410
  if SUMMARIES_JSON in zip.namelist():
456
411
  summaries_raw = _read_json(zip, SUMMARIES_JSON)
457
412
  if isinstance(summaries_raw, list):
458
413
  return [
459
- SampleSummary.model_validate(value, context=DESERIALIZING_CONTEXT)
414
+ EvalSampleSummary.model_validate(value, context=DESERIALIZING_CONTEXT)
460
415
  for value in summaries_raw
461
416
  ]
462
417
  else:
@@ -464,7 +419,7 @@ def _read_all_summaries(zip: ZipFile, count: int) -> list[SampleSummary]:
464
419
  f"Expected a list of summaries when reading {SUMMARIES_JSON}"
465
420
  )
466
421
  else:
467
- summaries: list[SampleSummary] = []
422
+ summaries: list[EvalSampleSummary] = []
468
423
  for i in range(1, count):
469
424
  summary_file = _journal_summary_file(i)
470
425
  summary_path = _journal_summary_path(summary_file)
@@ -472,7 +427,7 @@ def _read_all_summaries(zip: ZipFile, count: int) -> list[SampleSummary]:
472
427
  if isinstance(summary, list):
473
428
  summaries.extend(
474
429
  [
475
- SampleSummary.model_validate(
430
+ EvalSampleSummary.model_validate(
476
431
  value, context=DESERIALIZING_CONTEXT
477
432
  )
478
433
  for value in summary
@@ -8,7 +8,7 @@ from inspect_ai._util.constants import MODEL_NONE
8
8
  from inspect_ai._util.file import filesystem
9
9
  from inspect_ai._util.registry import registry_unqualified_name
10
10
 
11
- from .._log import EvalLog, EvalSample, EvalSpec
11
+ from .._log import EvalLog, EvalSample, EvalSampleSummary, EvalSpec
12
12
  from .recorder import Recorder
13
13
 
14
14
  logger = getLogger(__name__)
@@ -40,11 +40,7 @@ class FileRecorder(Recorder):
40
40
  cls, location: str, id: str | int, epoch: int = 1
41
41
  ) -> EvalSample:
42
42
  # establish the log to read from (might be cached)
43
- if cls.__last_read_sample_log and (cls.__last_read_sample_log[0] == "location"):
44
- eval_log = cls.__last_read_sample_log[1]
45
- else:
46
- eval_log = await cls.read_log(location)
47
- cls.__last_read_sample_log = (location, eval_log)
43
+ eval_log = await cls._log_file_maybe_cached(location)
48
44
 
49
45
  # throw if no samples
50
46
  if not eval_log.samples:
@@ -66,6 +62,32 @@ class FileRecorder(Recorder):
66
62
  else:
67
63
  return eval_sample
68
64
 
65
+ @classmethod
66
+ @override
67
+ async def read_log_sample_summaries(cls, location: str) -> list[EvalSampleSummary]:
68
+ # establish the log to read from (might be cached)
69
+ eval_log = await cls._log_file_maybe_cached(location)
70
+
71
+ # throw if no samples
72
+ if not eval_log.samples:
73
+ raise IndexError(f"No samples found in log {location}")
74
+
75
+ summaries: list[EvalSampleSummary] = []
76
+ for sample in eval_log.samples:
77
+ summaries.append(sample.summary())
78
+
79
+ return summaries
80
+
81
+ @classmethod
82
+ async def _log_file_maybe_cached(cls, location: str) -> EvalLog:
83
+ # establish the log to read from (might be cached)
84
+ if cls.__last_read_sample_log and (cls.__last_read_sample_log[0] == "location"):
85
+ eval_log = cls.__last_read_sample_log[1]
86
+ else:
87
+ eval_log = await cls.read_log(location)
88
+ cls.__last_read_sample_log = (location, eval_log)
89
+ return eval_log
90
+
69
91
  def _log_file_key(self, eval: EvalSpec) -> str:
70
92
  # clean underscores, slashes, and : from the log file key (so we can reliably parse it
71
93
  # later without worrying about underscores)
@@ -8,6 +8,7 @@ from inspect_ai.log._log import (
8
8
  EvalResults,
9
9
  EvalSample,
10
10
  EvalSampleReductions,
11
+ EvalSampleSummary,
11
12
  EvalSpec,
12
13
  EvalStats,
13
14
  )
@@ -57,6 +58,12 @@ class Recorder(abc.ABC):
57
58
  cls, location: str, id: str | int, epoch: int = 1
58
59
  ) -> EvalSample: ...
59
60
 
61
+ @classmethod
62
+ @abc.abstractmethod
63
+ async def read_log_sample_summaries(
64
+ cls, location: str
65
+ ) -> list[EvalSampleSummary]: ...
66
+
60
67
  @classmethod
61
68
  @abc.abstractmethod
62
69
  async def write_log(cls, location: str, log: EvalLog) -> None: ...
@@ -1,31 +1,9 @@
1
- from pydantic import BaseModel, Field, model_validator
1
+ from pydantic import BaseModel
2
2
 
3
3
  from inspect_ai.log._transcript import Event
4
- from inspect_ai.model._chat_message import ChatMessage
5
- from inspect_ai.scorer._metric import Score
6
4
 
7
5
 
8
6
  class SampleEvent(BaseModel):
9
7
  id: str | int
10
8
  epoch: int
11
9
  event: Event
12
-
13
-
14
- class SampleSummary(BaseModel):
15
- id: int | str
16
- epoch: int
17
- input: str | list[ChatMessage]
18
- target: str | list[str]
19
- completed: bool = Field(default=False)
20
- scores: dict[str, Score] | None = Field(default=None)
21
- error: str | None = Field(default=None)
22
- limit: str | None = Field(default=None)
23
- retries: int | None = Field(default=None)
24
-
25
- @model_validator(mode="after")
26
- def thin_scores(self) -> "SampleSummary":
27
- if self.scores is not None:
28
- self.scores = {
29
- key: Score(value=score.value) for key, score in self.scores.items()
30
- }
31
- return self
@@ -5,12 +5,11 @@ from typing import AsyncGenerator, Iterator, Literal
5
5
 
6
6
  from shortuuid import uuid
7
7
 
8
- from inspect_ai._util.constants import SAMPLE_SUBTASK
9
8
  from inspect_ai.dataset._dataset import Sample
10
9
  from inspect_ai.util._sandbox import SandboxConnection
11
10
  from inspect_ai.util._sandbox.context import sandbox_connections
12
11
 
13
- from ._transcript import Transcript, transcript
12
+ from ._transcript import ModelEvent, Transcript
14
13
 
15
14
 
16
15
  class ActiveSample:
@@ -47,7 +46,6 @@ class ActiveSample:
47
46
  self.total_tokens = 0
48
47
  self.transcript = transcript
49
48
  self.sandboxes = sandboxes
50
- self.retry_count = 0
51
49
  self._interrupt_action: Literal["score", "error"] | None = None
52
50
 
53
51
  @property
@@ -119,14 +117,6 @@ def sample_active() -> ActiveSample | None:
119
117
  return _sample_active.get(None)
120
118
 
121
119
 
122
- def active_sample_token_limit() -> int | None:
123
- active = sample_active()
124
- if active:
125
- return active.token_limit
126
- else:
127
- return None
128
-
129
-
130
120
  def set_active_sample_token_limit(token_limit: int | None) -> None:
131
121
  active = sample_active()
132
122
  if active:
@@ -159,27 +149,26 @@ def set_active_sample_total_messages(total_messages: int) -> None:
159
149
  active.total_messages = total_messages
160
150
 
161
151
 
152
+ _active_model_event: ContextVar[ModelEvent | None] = ContextVar(
153
+ "_active_model_event", default=None
154
+ )
155
+
156
+
162
157
  @contextlib.contextmanager
163
- def track_active_sample_retries() -> Iterator[None]:
164
- reset_active_sample_retries()
158
+ def track_active_model_event(event: ModelEvent) -> Iterator[None]:
159
+ token = _active_model_event.set(event)
165
160
  try:
166
161
  yield
167
162
  finally:
168
- reset_active_sample_retries()
169
-
170
-
171
- def reset_active_sample_retries() -> None:
172
- active = sample_active()
173
- if active:
174
- active.retry_count = 0
163
+ _active_model_event.reset(token)
175
164
 
176
165
 
177
166
  def report_active_sample_retry() -> None:
178
- active = sample_active()
179
- if active:
180
- # only do this for the top level subtask
181
- if transcript().name == SAMPLE_SUBTASK:
182
- active.retry_count = active.retry_count + 1
167
+ model_event = _active_model_event.get()
168
+ if model_event is not None:
169
+ if model_event.retries is None:
170
+ model_event.retries = 0
171
+ model_event.retries = model_event.retries + 1
183
172
 
184
173
 
185
174
  _sample_active: ContextVar[ActiveSample | None] = ContextVar(
@@ -14,12 +14,19 @@ from typing import (
14
14
  Union,
15
15
  )
16
16
 
17
- from pydantic import BaseModel, ConfigDict, Field, JsonValue, field_serializer
17
+ from pydantic import (
18
+ BaseModel,
19
+ ConfigDict,
20
+ Field,
21
+ JsonValue,
22
+ field_serializer,
23
+ )
18
24
  from shortuuid import uuid
19
25
 
20
- from inspect_ai._util.constants import SAMPLE_SUBTASK
26
+ from inspect_ai._util.constants import DESERIALIZING
21
27
  from inspect_ai._util.error import EvalError
22
- from inspect_ai._util.json import JsonChange, json_changes
28
+ from inspect_ai._util.json import JsonChange
29
+ from inspect_ai._util.logger import warn_once
23
30
  from inspect_ai._util.working import sample_working_time
24
31
  from inspect_ai.dataset._dataset import Sample
25
32
  from inspect_ai.log._message import LoggingMessage
@@ -28,7 +35,6 @@ from inspect_ai.model._generate_config import GenerateConfig
28
35
  from inspect_ai.model._model_call import ModelCall
29
36
  from inspect_ai.model._model_output import ModelOutput
30
37
  from inspect_ai.scorer._metric import Score
31
- from inspect_ai.solver._task_state import state_jsonable
32
38
  from inspect_ai.tool._tool import ToolResult
33
39
  from inspect_ai.tool._tool_call import (
34
40
  ToolCall,
@@ -38,6 +44,7 @@ from inspect_ai.tool._tool_call import (
38
44
  )
39
45
  from inspect_ai.tool._tool_choice import ToolChoice
40
46
  from inspect_ai.tool._tool_info import ToolInfo
47
+ from inspect_ai.util._span import current_span_id
41
48
  from inspect_ai.util._store import store, store_changes, store_jsonable
42
49
 
43
50
  logger = getLogger(__name__)
@@ -51,6 +58,9 @@ class BaseEvent(BaseModel):
51
58
  }
52
59
  id_: str = Field(default_factory=lambda: str(uuid()), exclude=True)
53
60
 
61
+ span_id: str | None = Field(default=None)
62
+ """Span the event occurred within."""
63
+
54
64
  timestamp: datetime = Field(default_factory=datetime.now)
55
65
  """Clock time at which event occurred."""
56
66
 
@@ -60,6 +70,17 @@ class BaseEvent(BaseModel):
60
70
  pending: bool | None = Field(default=None)
61
71
  """Is this event pending?"""
62
72
 
73
+ def model_post_init(self, __context: Any) -> None:
74
+ # check if deserializing
75
+ is_deserializing = isinstance(__context, dict) and __context.get(
76
+ DESERIALIZING, False
77
+ )
78
+
79
+ # Generate context id fields if not deserializing
80
+ if not is_deserializing:
81
+ if self.span_id is None:
82
+ self.span_id = current_span_id()
83
+
63
84
  @field_serializer("timestamp")
64
85
  def serialize_timestamp(self, dt: datetime) -> str:
65
86
  return dt.astimezone().isoformat()
@@ -141,6 +162,9 @@ class ModelEvent(BaseEvent):
141
162
  output: ModelOutput
142
163
  """Output from model."""
143
164
 
165
+ retries: int | None = Field(default=None)
166
+ """Retries for the model API request."""
167
+
144
168
  error: str | None = Field(default=None)
145
169
  """Error which occurred during model call."""
146
170
 
@@ -197,7 +221,13 @@ class ToolEvent(BaseEvent):
197
221
  """Error that occurred during tool call."""
198
222
 
199
223
  events: list["Event"] = Field(default_factory=list)
200
- """Transcript of events for tool."""
224
+ """Transcript of events for tool.
225
+
226
+ Note that events are no longer recorded separately within
227
+ tool events but rather all events are recorded in the main
228
+ transcript. This field is deprecated and here for backwards
229
+ compatibility with transcripts that have sub-events.
230
+ """
201
231
 
202
232
  completed: datetime | None = Field(default=None)
203
233
  """Time that tool call completed (see `timestamp` for started)"""
@@ -216,7 +246,6 @@ class ToolEvent(BaseEvent):
216
246
  result: ToolResult,
217
247
  truncated: tuple[int, int] | None,
218
248
  error: ToolCallError | None,
219
- events: list["Event"],
220
249
  waiting_time: float,
221
250
  agent: str | None,
222
251
  failed: bool | None,
@@ -224,7 +253,6 @@ class ToolEvent(BaseEvent):
224
253
  self.result = result
225
254
  self.truncated = truncated
226
255
  self.error = error
227
- self.events = events
228
256
  self.pending = None
229
257
  completed = datetime.now()
230
258
  self.completed = completed
@@ -396,6 +424,35 @@ class ScoreEvent(BaseEvent):
396
424
  """Was this an intermediate scoring?"""
397
425
 
398
426
 
427
+ class SpanBeginEvent(BaseEvent):
428
+ """Mark the beginning of a transcript span."""
429
+
430
+ event: Literal["span_begin"] = Field(default="span_begin")
431
+ """Event type."""
432
+
433
+ id: str
434
+ """Unique identifier for span."""
435
+
436
+ parent_id: str | None = Field(default=None)
437
+ """Identifier for parent span."""
438
+
439
+ type: str | None = Field(default=None)
440
+ """Optional 'type' field for span."""
441
+
442
+ name: str
443
+ """Span name."""
444
+
445
+
446
+ class SpanEndEvent(BaseEvent):
447
+ """Mark the end of a transcript span."""
448
+
449
+ event: Literal["span_end"] = Field(default="span_end")
450
+ """Event type."""
451
+
452
+ id: str
453
+ """Unique identifier for span."""
454
+
455
+
399
456
  class StepEvent(BaseEvent):
400
457
  """Step within current sample or subtask."""
401
458
 
@@ -431,7 +488,13 @@ class SubtaskEvent(BaseEvent):
431
488
  """Subtask function result."""
432
489
 
433
490
  events: list["Event"] = Field(default_factory=list)
434
- """Transcript of events for subtask."""
491
+ """Transcript of events for subtask.
492
+
493
+ Note that events are no longer recorded separately within
494
+ subtasks but rather all events are recorded in the main
495
+ transcript. This field is deprecated and here for backwards
496
+ compatibility with transcripts that have sub-events.
497
+ """
435
498
 
436
499
  completed: datetime | None = Field(default=None)
437
500
  """Time that subtask completed (see `timestamp` for started)"""
@@ -461,6 +524,8 @@ Event: TypeAlias = Union[
461
524
  | ErrorEvent
462
525
  | LoggerEvent
463
526
  | InfoEvent
527
+ | SpanBeginEvent
528
+ | SpanEndEvent
464
529
  | StepEvent
465
530
  | SubtaskEvent,
466
531
  ]
@@ -474,8 +539,7 @@ class Transcript:
474
539
 
475
540
  _event_logger: Callable[[Event], None] | None
476
541
 
477
- def __init__(self, name: str = "") -> None:
478
- self.name = name
542
+ def __init__(self) -> None:
479
543
  self._event_logger = None
480
544
  self._events: list[Event] = []
481
545
 
@@ -492,19 +556,20 @@ class Transcript:
492
556
  def step(self, name: str, type: str | None = None) -> Iterator[None]:
493
557
  """Context manager for recording StepEvent.
494
558
 
559
+ The `step()` context manager is deprecated and will be removed in a future version.
560
+ Please use the `span()` context manager instead.
561
+
495
562
  Args:
496
563
  name (str): Step name.
497
564
  type (str | None): Optional step type.
498
565
  """
499
- # step event
500
- self._event(StepEvent(action="begin", name=name, type=type))
501
-
502
- # run the step (tracking state/store changes)
503
- with track_state_changes(type), track_store_changes():
504
- yield
505
-
506
- # end step event
507
- self._event(StepEvent(action="end", name=name, type=type))
566
+ warn_once(
567
+ logger,
568
+ "The `transcript().step()` context manager is deprecated and will "
569
+ + "be removed in a future version. Please replace the call to step() "
570
+ + "with a call to span().",
571
+ )
572
+ yield
508
573
 
509
574
  @property
510
575
  def events(self) -> Sequence[Event]:
@@ -545,23 +610,6 @@ def track_store_changes() -> Iterator[None]:
545
610
  transcript()._event(StoreEvent(changes=changes))
546
611
 
547
612
 
548
- @contextlib.contextmanager
549
- def track_state_changes(type: str | None = None) -> Iterator[None]:
550
- # we only want to track for step() inside the the sample
551
- # (solver level tracking is handled already and there are
552
- # no state changes in subtasks)
553
- if transcript().name == SAMPLE_SUBTASK and type != "solver":
554
- before = state_jsonable()
555
- yield
556
- after = state_jsonable()
557
-
558
- changes = json_changes(before, after)
559
- if changes:
560
- transcript()._event(StateEvent(changes=changes))
561
- else:
562
- yield
563
-
564
-
565
613
  def init_transcript(transcript: Transcript) -> None:
566
614
  _transcript.set(transcript)
567
615