inspect-ai 0.3.98__py3-none-any.whl → 0.3.100__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131)
  1. inspect_ai/__init__.py +2 -0
  2. inspect_ai/_cli/log.py +1 -1
  3. inspect_ai/_display/core/config.py +11 -5
  4. inspect_ai/_display/core/panel.py +66 -2
  5. inspect_ai/_display/core/textual.py +5 -2
  6. inspect_ai/_display/plain/display.py +1 -0
  7. inspect_ai/_display/rich/display.py +2 -2
  8. inspect_ai/_display/textual/widgets/transcript.py +41 -1
  9. inspect_ai/_eval/run.py +12 -4
  10. inspect_ai/_eval/score.py +2 -4
  11. inspect_ai/_eval/task/log.py +1 -1
  12. inspect_ai/_eval/task/run.py +59 -81
  13. inspect_ai/_eval/task/task.py +1 -1
  14. inspect_ai/_util/_async.py +1 -1
  15. inspect_ai/_util/content.py +11 -6
  16. inspect_ai/_util/interrupt.py +2 -2
  17. inspect_ai/_util/text.py +7 -0
  18. inspect_ai/_util/working.py +8 -37
  19. inspect_ai/_view/__init__.py +0 -0
  20. inspect_ai/_view/schema.py +3 -1
  21. inspect_ai/_view/view.py +14 -0
  22. inspect_ai/_view/www/CLAUDE.md +15 -0
  23. inspect_ai/_view/www/dist/assets/index.css +273 -169
  24. inspect_ai/_view/www/dist/assets/index.js +20079 -17019
  25. inspect_ai/_view/www/log-schema.json +122 -8
  26. inspect_ai/_view/www/package.json +5 -1
  27. inspect_ai/_view/www/src/@types/log.d.ts +20 -2
  28. inspect_ai/_view/www/src/app/App.tsx +1 -15
  29. inspect_ai/_view/www/src/app/appearance/icons.ts +4 -1
  30. inspect_ai/_view/www/src/app/content/MetaDataGrid.tsx +24 -6
  31. inspect_ai/_view/www/src/app/content/MetadataGrid.module.css +0 -5
  32. inspect_ai/_view/www/src/app/content/RenderedContent.tsx +221 -205
  33. inspect_ai/_view/www/src/app/log-view/LogViewContainer.tsx +2 -1
  34. inspect_ai/_view/www/src/app/log-view/tabs/SamplesTab.tsx +5 -0
  35. inspect_ai/_view/www/src/app/routing/url.ts +84 -4
  36. inspect_ai/_view/www/src/app/samples/InlineSampleDisplay.module.css +0 -5
  37. inspect_ai/_view/www/src/app/samples/SampleDialog.module.css +1 -1
  38. inspect_ai/_view/www/src/app/samples/SampleDisplay.module.css +7 -0
  39. inspect_ai/_view/www/src/app/samples/SampleDisplay.tsx +26 -19
  40. inspect_ai/_view/www/src/app/samples/SampleSummaryView.module.css +1 -2
  41. inspect_ai/_view/www/src/app/samples/chat/ChatMessage.tsx +8 -6
  42. inspect_ai/_view/www/src/app/samples/chat/ChatMessageRow.tsx +0 -4
  43. inspect_ai/_view/www/src/app/samples/chat/ChatViewVirtualList.tsx +3 -2
  44. inspect_ai/_view/www/src/app/samples/chat/MessageContent.tsx +2 -0
  45. inspect_ai/_view/www/src/app/samples/chat/MessageContents.tsx +2 -0
  46. inspect_ai/_view/www/src/app/samples/chat/messages.ts +1 -0
  47. inspect_ai/_view/www/src/app/samples/chat/tools/ToolCallView.tsx +1 -0
  48. inspect_ai/_view/www/src/app/samples/list/SampleRow.tsx +1 -1
  49. inspect_ai/_view/www/src/app/samples/scores/SampleScoresGrid.module.css +2 -2
  50. inspect_ai/_view/www/src/app/samples/transcript/ErrorEventView.tsx +2 -3
  51. inspect_ai/_view/www/src/app/samples/transcript/InfoEventView.tsx +1 -1
  52. inspect_ai/_view/www/src/app/samples/transcript/InputEventView.tsx +1 -2
  53. inspect_ai/_view/www/src/app/samples/transcript/ModelEventView.module.css +1 -1
  54. inspect_ai/_view/www/src/app/samples/transcript/ModelEventView.tsx +1 -1
  55. inspect_ai/_view/www/src/app/samples/transcript/SampleInitEventView.tsx +1 -1
  56. inspect_ai/_view/www/src/app/samples/transcript/SampleLimitEventView.tsx +3 -2
  57. inspect_ai/_view/www/src/app/samples/transcript/SandboxEventView.tsx +4 -5
  58. inspect_ai/_view/www/src/app/samples/transcript/ScoreEventView.tsx +1 -1
  59. inspect_ai/_view/www/src/app/samples/transcript/SpanEventView.tsx +1 -2
  60. inspect_ai/_view/www/src/app/samples/transcript/StepEventView.tsx +1 -3
  61. inspect_ai/_view/www/src/app/samples/transcript/SubtaskEventView.tsx +1 -2
  62. inspect_ai/_view/www/src/app/samples/transcript/ToolEventView.tsx +3 -4
  63. inspect_ai/_view/www/src/app/samples/transcript/TranscriptPanel.module.css +42 -0
  64. inspect_ai/_view/www/src/app/samples/transcript/TranscriptPanel.tsx +77 -0
  65. inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualList.tsx +27 -71
  66. inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualListComponent.module.css +13 -3
  67. inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualListComponent.tsx +27 -2
  68. inspect_ai/_view/www/src/app/samples/transcript/event/EventPanel.module.css +1 -0
  69. inspect_ai/_view/www/src/app/samples/transcript/event/EventPanel.tsx +21 -22
  70. inspect_ai/_view/www/src/app/samples/transcript/outline/OutlineRow.module.css +45 -0
  71. inspect_ai/_view/www/src/app/samples/transcript/outline/OutlineRow.tsx +223 -0
  72. inspect_ai/_view/www/src/app/samples/transcript/outline/TranscriptOutline.module.css +10 -0
  73. inspect_ai/_view/www/src/app/samples/transcript/outline/TranscriptOutline.tsx +258 -0
  74. inspect_ai/_view/www/src/app/samples/transcript/outline/tree-visitors.ts +187 -0
  75. inspect_ai/_view/www/src/app/samples/transcript/state/StateEventRenderers.tsx +8 -1
  76. inspect_ai/_view/www/src/app/samples/transcript/state/StateEventView.tsx +3 -4
  77. inspect_ai/_view/www/src/app/samples/transcript/transform/hooks.ts +78 -0
  78. inspect_ai/_view/www/src/app/samples/transcript/transform/treeify.ts +340 -135
  79. inspect_ai/_view/www/src/app/samples/transcript/transform/utils.ts +3 -0
  80. inspect_ai/_view/www/src/app/samples/transcript/types.ts +2 -0
  81. inspect_ai/_view/www/src/app/types.ts +5 -1
  82. inspect_ai/_view/www/src/client/api/api-browser.ts +2 -2
  83. inspect_ai/_view/www/src/components/LiveVirtualList.tsx +6 -1
  84. inspect_ai/_view/www/src/components/MarkdownDiv.tsx +1 -1
  85. inspect_ai/_view/www/src/components/PopOver.tsx +422 -0
  86. inspect_ai/_view/www/src/components/PulsingDots.module.css +9 -9
  87. inspect_ai/_view/www/src/components/PulsingDots.tsx +4 -1
  88. inspect_ai/_view/www/src/components/StickyScroll.tsx +183 -0
  89. inspect_ai/_view/www/src/components/TabSet.tsx +4 -0
  90. inspect_ai/_view/www/src/state/hooks.ts +52 -2
  91. inspect_ai/_view/www/src/state/logSlice.ts +4 -3
  92. inspect_ai/_view/www/src/state/samplePolling.ts +8 -0
  93. inspect_ai/_view/www/src/state/sampleSlice.ts +53 -9
  94. inspect_ai/_view/www/src/state/scrolling.ts +152 -0
  95. inspect_ai/_view/www/src/utils/attachments.ts +7 -0
  96. inspect_ai/_view/www/src/utils/python.ts +18 -0
  97. inspect_ai/_view/www/yarn.lock +269 -6
  98. inspect_ai/agent/_react.py +12 -7
  99. inspect_ai/agent/_run.py +46 -11
  100. inspect_ai/analysis/beta/_dataframe/samples/table.py +19 -18
  101. inspect_ai/log/_bundle.py +5 -3
  102. inspect_ai/log/_log.py +3 -3
  103. inspect_ai/log/_recorders/file.py +2 -9
  104. inspect_ai/log/_transcript.py +1 -1
  105. inspect_ai/model/_call_tools.py +6 -2
  106. inspect_ai/model/_openai.py +1 -1
  107. inspect_ai/model/_openai_responses.py +78 -39
  108. inspect_ai/model/_openai_web_search.py +31 -0
  109. inspect_ai/model/_providers/anthropic.py +3 -6
  110. inspect_ai/model/_providers/azureai.py +72 -3
  111. inspect_ai/model/_providers/openai.py +2 -1
  112. inspect_ai/model/_providers/providers.py +1 -1
  113. inspect_ai/scorer/_metric.py +1 -2
  114. inspect_ai/solver/_task_state.py +2 -2
  115. inspect_ai/tool/_tool.py +6 -2
  116. inspect_ai/tool/_tool_def.py +27 -4
  117. inspect_ai/tool/_tool_info.py +2 -0
  118. inspect_ai/tool/_tools/_web_search/_google.py +15 -4
  119. inspect_ai/tool/_tools/_web_search/_tavily.py +35 -12
  120. inspect_ai/tool/_tools/_web_search/_web_search.py +214 -45
  121. inspect_ai/util/__init__.py +6 -0
  122. inspect_ai/util/_json.py +3 -0
  123. inspect_ai/util/_limit.py +374 -141
  124. inspect_ai/util/_sandbox/docker/compose.py +20 -11
  125. inspect_ai/util/_span.py +1 -1
  126. {inspect_ai-0.3.98.dist-info → inspect_ai-0.3.100.dist-info}/METADATA +3 -3
  127. {inspect_ai-0.3.98.dist-info → inspect_ai-0.3.100.dist-info}/RECORD +131 -117
  128. {inspect_ai-0.3.98.dist-info → inspect_ai-0.3.100.dist-info}/WHEEL +1 -1
  129. {inspect_ai-0.3.98.dist-info → inspect_ai-0.3.100.dist-info}/entry_points.txt +0 -0
  130. {inspect_ai-0.3.98.dist-info → inspect_ai-0.3.100.dist-info}/licenses/LICENSE +0 -0
  131. {inspect_ai-0.3.98.dist-info → inspect_ai-0.3.100.dist-info}/top_level.txt +0 -0
inspect_ai/__init__.py CHANGED
@@ -10,6 +10,7 @@ from inspect_ai._eval.score import score, score_async
10
10
  from inspect_ai._eval.task import Epochs, Task, TaskInfo, task_with
11
11
  from inspect_ai._eval.task.tasks import Tasks
12
12
  from inspect_ai._util.constants import PKG_NAME
13
+ from inspect_ai._view.view import view
13
14
  from inspect_ai.agent._human.agent import human_cli
14
15
  from inspect_ai.solver._human_agent import human_agent
15
16
 
@@ -32,4 +33,5 @@ __all__ = [
32
33
  "TaskInfo",
33
34
  "task",
34
35
  "task_with",
36
+ "view",
35
37
  ]
inspect_ai/_cli/log.py CHANGED
@@ -199,6 +199,6 @@ def view_resource(file: str) -> str:
199
199
 
200
200
 
201
201
  def view_type_resource(file: str) -> str:
202
- resource = PKG_PATH / "_view" / "www" / "src" / "types" / file
202
+ resource = PKG_PATH / "_view" / "www" / "src" / "@types" / file
203
203
  with open(resource, "r", encoding="utf-8") as f:
204
204
  return f.read()
@@ -1,4 +1,8 @@
1
+ from rich.console import RenderableType
2
+ from rich.text import Text
3
+
1
4
  from inspect_ai._util.registry import is_model_dict, is_registry_dict
5
+ from inspect_ai._util.text import truncate_text
2
6
  from inspect_ai.log._log import eval_config_defaults
3
7
 
4
8
  from .display import TaskProfile
@@ -6,7 +10,7 @@ from .display import TaskProfile
6
10
 
7
11
  def task_config(
8
12
  profile: TaskProfile, generate_config: bool = True, style: str = ""
9
- ) -> str:
13
+ ) -> RenderableType:
10
14
  # merge config
11
15
  # wind params back for display
12
16
  task_args = dict(profile.task_args)
@@ -39,15 +43,17 @@ def task_config(
39
43
  elif name not in ["limit", "model", "response_schema", "log_shared"]:
40
44
  if isinstance(value, list):
41
45
  value = ",".join([str(v) for v in value])
46
+ elif isinstance(value, dict):
47
+ value = "{...}"
42
48
  if isinstance(value, str):
49
+ value = truncate_text(value, 50)
43
50
  value = value.replace("[", "\\[")
44
51
  config_print.append(f"{name}: {value}")
45
52
  values = ", ".join(config_print)
46
53
  if values:
47
- if style:
48
- return f"[{style}]{values}[/{style}]"
49
- else:
50
- return values
54
+ values_text = Text(values, style=style)
55
+ values_text.truncate(500, overflow="ellipsis")
56
+ return values_text
51
57
  else:
52
58
  return ""
53
59
 
@@ -9,6 +9,7 @@ from rich.text import Text
9
9
  from inspect_ai._util.constants import CONSOLE_DISPLAY_WIDTH
10
10
  from inspect_ai._util.path import cwd_relative_path
11
11
  from inspect_ai._util.registry import registry_unqualified_name
12
+ from inspect_ai.util._display import display_type
12
13
 
13
14
  from .display import TaskProfile
14
15
  from .rich import is_vscode_notebook, rich_theme
@@ -24,7 +25,13 @@ def task_panel(
24
25
  | None,
25
26
  footer: RenderableType | tuple[RenderableType, RenderableType] | None,
26
27
  log_location: str | None,
27
- ) -> Panel:
28
+ ) -> RenderableType:
29
+ # dispatch to plain handler if we are in plain mode
30
+ if display_type() == "plain":
31
+ return task_panel_plain(
32
+ profile, show_model, body, subtitle, footer, log_location
33
+ )
34
+
28
35
  # rendering context
29
36
  theme = rich_theme()
30
37
  console = rich.get_console()
@@ -93,7 +100,7 @@ def task_panel(
93
100
  # create panel w/ title
94
101
  panel = Panel(
95
102
  root,
96
- title=f"[bold][{theme.meta}]{task_title(profile, show_model)}[/{theme.meta}][/bold]",
103
+ title=task_panel_title(profile, show_model),
97
104
  title_align="left",
98
105
  width=width,
99
106
  expand=True,
@@ -101,6 +108,63 @@ def task_panel(
101
108
  return panel
102
109
 
103
110
 
111
+ def task_panel_plain(
112
+ profile: TaskProfile,
113
+ show_model: bool,
114
+ body: RenderableType,
115
+ subtitle: RenderableType
116
+ | str
117
+ | Tuple[RenderableType | str, RenderableType | str]
118
+ | None,
119
+ footer: RenderableType | tuple[RenderableType, RenderableType] | None,
120
+ log_location: str | None,
121
+ ) -> RenderableType:
122
+ # delimiter text
123
+ delimeter = "---------------------------------------------------------"
124
+
125
+ # root table for output
126
+ table = Table.grid(expand=False)
127
+ table.add_column()
128
+ table.add_row(delimeter)
129
+
130
+ # title and subtitle
131
+ table.add_row(task_panel_title(profile, show_model))
132
+ if isinstance(subtitle, tuple):
133
+ subtitle = subtitle[0]
134
+ table.add_row(subtitle)
135
+
136
+ # task info
137
+ if body:
138
+ table.add_row(body)
139
+
140
+ # footer
141
+ if isinstance(footer, tuple):
142
+ footer = footer[0]
143
+ if footer:
144
+ table.add_row(footer)
145
+
146
+ # log location
147
+ if log_location:
148
+ # Print a cwd relative path
149
+ try:
150
+ log_location_relative = cwd_relative_path(log_location, walk_up=True)
151
+ except ValueError:
152
+ log_location_relative = log_location
153
+ table.add_row(f"Log: {log_location_relative}")
154
+
155
+ table.add_row(delimeter)
156
+ table.add_row("")
157
+
158
+ return table
159
+
160
+
161
+ def task_panel_title(profile: TaskProfile, show_model: bool) -> str:
162
+ theme = rich_theme()
163
+ return (
164
+ f"[bold][{theme.meta}]{task_title(profile, show_model)}[/{theme.meta}][/bold]"
165
+ )
166
+
167
+
104
168
  def to_renderable(item: RenderableType | str, style: str = "") -> RenderableType:
105
169
  if isinstance(item, str):
106
170
  return Text.from_markup(item, style=style)
@@ -8,8 +8,6 @@ logger = getLogger(__name__)
8
8
  # force mouse support for textual -- this works around an issue where
9
9
  # mouse events are disabled after a reload of the vs code ide, see:
10
10
  # https://github.com/Textualize/textual/issues/5380
11
- # ansi codes for enabling mouse support are idempotent so it is fine
12
- # to do this even in cases where mouse support is already enabled.
13
11
  # we try/catch since we aren't 100% sure there aren't cases where doing
14
12
  # this won't raise and we'd rather not fail hard in in these case
15
13
  def textual_enable_mouse_support(driver: Driver) -> None:
@@ -17,5 +15,10 @@ def textual_enable_mouse_support(driver: Driver) -> None:
17
15
  if enable_mouse_support:
18
16
  try:
19
17
  enable_mouse_support()
18
+ # Re-enable SGR-Pixels format if it was previously enabled.
19
+ # See #1943.
20
+ enable_mouse_pixels = getattr(driver, "_enable_mouse_pixels", None)
21
+ if enable_mouse_pixels and getattr(driver, "_mouse_pixels", False):
22
+ enable_mouse_pixels()
20
23
  except Exception as ex:
21
24
  logger.warning(f"Error enabling mouse support: {ex}")
@@ -208,3 +208,4 @@ class PlainTaskDisplay(TaskDisplay):
208
208
  def complete(self, result: TaskResult) -> None:
209
209
  self.task.result = result
210
210
  self._print_status()
211
+ print("")
@@ -341,8 +341,6 @@ def tasks_live_status(
341
341
 
342
342
  # get config
343
343
  config = task_config(tasks[0].profile, generate_config=False, style=theme.light)
344
- if config:
345
- config += "\n"
346
344
 
347
345
  # build footer table
348
346
  footer_table = Table.grid(expand=True)
@@ -356,6 +354,8 @@ def tasks_live_status(
356
354
  layout_table = Table.grid(expand=True)
357
355
  layout_table.add_column()
358
356
  layout_table.add_row(config)
357
+ if config:
358
+ layout_table.add_row("")
359
359
  layout_table.add_row(progress)
360
360
  layout_table.add_row(footer_table)
361
361
 
@@ -84,6 +84,7 @@ class TranscriptView(ScrollableContainer):
84
84
  scroll_to_end = (
85
85
  new_sample or abs(self.scroll_y - self.max_scroll_y) <= 20
86
86
  )
87
+
87
88
  async with self.batch():
88
89
  await self.remove_children()
89
90
  await self.mount_all(
@@ -100,9 +101,32 @@ class TranscriptView(ScrollableContainer):
100
101
  else:
101
102
  self._pending_sample = sample
102
103
 
103
- def _widgets_for_events(self, events: Sequence[Event]) -> list[Widget]:
104
+ def _widgets_for_events(
105
+ self, events: Sequence[Event], limit: int = 10
106
+ ) -> list[Widget]:
104
107
  widgets: list[Widget] = []
108
+
109
+ # filter the events to the <limit> most recent
110
+ filtered_events = events
111
+ if len(events) > limit:
112
+ filtered_events = filtered_events[-limit:]
113
+
114
+ # find the sample init event
115
+ sample_init: SampleInitEvent | None = None
105
116
  for event in events:
117
+ if isinstance(event, SampleInitEvent):
118
+ sample_init = event
119
+ break
120
+
121
+ # add the sample init event if it isn't already in the event list
122
+ if sample_init and sample_init not in filtered_events:
123
+ filtered_events = [sample_init] + list(filtered_events)
124
+
125
+ # compute how many events we filtered out
126
+ filtered_count = len(events) - len(filtered_events)
127
+ showed_filtered_count = False
128
+
129
+ for event in filtered_events:
106
130
  display = render_event(event)
107
131
  if display:
108
132
  for d in display:
@@ -118,6 +142,22 @@ class TranscriptView(ScrollableContainer):
118
142
  set_transcript_markdown_options(d.content)
119
143
  widgets.append(Static(d.content, markup=False))
120
144
  widgets.append(Static(Text(" ")))
145
+
146
+ if not showed_filtered_count and filtered_count > 0:
147
+ showed_filtered_count = True
148
+
149
+ widgets.append(
150
+ Static(
151
+ transcript_separator(
152
+ f"{filtered_count} events..."
153
+ if filtered_count > 1
154
+ else "1 event...",
155
+ self.app.current_theme.primary,
156
+ )
157
+ )
158
+ )
159
+ widgets.append(Static(Text(" ")))
160
+
121
161
  return widgets
122
162
 
123
163
 
inspect_ai/_eval/run.py CHANGED
@@ -298,10 +298,13 @@ async def run_multiple(tasks: list[TaskRunOptions], parallel: int) -> list[EvalL
298
298
 
299
299
  # setup pending tasks, queue, and results
300
300
  pending_tasks = tasks.copy()
301
- results: list[EvalLog] = []
301
+ results: list[tuple[int, EvalLog]] = []
302
302
  tasks_completed = 0
303
303
  total_tasks = len(tasks)
304
304
 
305
+ # Create a mapping from task to its original index
306
+ task_to_original_index = {id(task): i for i, task in enumerate(tasks)}
307
+
305
308
  # produce/consume tasks
306
309
  send_channel, receive_channel = anyio.create_memory_object_stream[TaskRunOptions](
307
310
  parallel * 2
@@ -322,7 +325,7 @@ async def run_multiple(tasks: list[TaskRunOptions], parallel: int) -> list[EvalL
322
325
  # among those models, pick one with the least usage
323
326
  model = min(models_with_pending, key=lambda m: model_counts[m])
324
327
 
325
- # now we know theres at least one pending task for this model so its safe to pick it
328
+ # now we know there's at least one pending task for this model so it's safe to pick it
326
329
  next_task = next(t for t in pending_tasks if str(t.model) == model)
327
330
  pending_tasks.remove(next_task)
328
331
  model_counts[str(next_task.model)] += 1
@@ -339,6 +342,8 @@ async def run_multiple(tasks: list[TaskRunOptions], parallel: int) -> list[EvalL
339
342
  nonlocal tasks_completed
340
343
  async for task_options in receive_channel:
341
344
  result: EvalLog | None = None
345
+ # Get the original index of this task
346
+ original_index = task_to_original_index[id(task_options)]
342
347
 
343
348
  # run the task
344
349
  try:
@@ -354,11 +359,13 @@ async def run_multiple(tasks: list[TaskRunOptions], parallel: int) -> list[EvalL
354
359
  # see: https://docs.python.org/3/faq/programming.html#why-do-lambdas-defined-in-a-loop-with-different-values-all-return-the-same-result
355
360
  def create_task_runner(
356
361
  options: TaskRunOptions = task_options,
362
+ idx: int = original_index,
357
363
  ) -> Callable[[], Awaitable[None]]:
358
364
  async def run_task() -> None:
359
365
  nonlocal result
360
366
  result = await task_run(options)
361
- results.append(result)
367
+ # Store result with its original index
368
+ results.append((idx, result))
362
369
 
363
370
  return run_task
364
371
 
@@ -426,7 +433,8 @@ async def run_multiple(tasks: list[TaskRunOptions], parallel: int) -> list[EvalL
426
433
 
427
434
  clear_task_screen()
428
435
 
429
- return results
436
+ # Sort results by original index and return just the values
437
+ return [r for _, r in sorted(results)]
430
438
 
431
439
 
432
440
  def resolve_task_sample_ids(
inspect_ai/_eval/score.py CHANGED
@@ -1,7 +1,7 @@
1
1
  import functools
2
2
  from copy import deepcopy
3
3
  from pathlib import Path
4
- from typing import Any, Callable, Literal, cast
4
+ from typing import Any, Callable, Literal
5
5
 
6
6
  import anyio
7
7
 
@@ -270,9 +270,7 @@ def metrics_from_log(log: EvalLog) -> list[Metric] | dict[str, list[Metric]] | N
270
270
 
271
271
 
272
272
  def metric_from_log(metric: EvalMetricDefinition) -> Metric:
273
- return cast(
274
- Metric, registry_create("metric", metric.name, **(metric.options or {}))
275
- )
273
+ return registry_create("metric", metric.name, **(metric.options or {}))
276
274
 
277
275
 
278
276
  def reducers_from_log(log: EvalLog) -> list[ScoreReducer] | None:
@@ -56,7 +56,7 @@ class TaskLogger:
56
56
  def __init__(
57
57
  self,
58
58
  task_name: str,
59
- task_version: int,
59
+ task_version: int | str,
60
60
  task_file: str | None,
61
61
  task_registry_name: str | None,
62
62
  task_id: str | None,
@@ -35,11 +35,7 @@ from inspect_ai._util.registry import (
35
35
  registry_log_name,
36
36
  registry_unqualified_name,
37
37
  )
38
- from inspect_ai._util.working import (
39
- end_sample_working_limit,
40
- init_sample_working_limit,
41
- sample_waiting_time,
42
- )
38
+ from inspect_ai._util.working import init_sample_working_time, sample_waiting_time
43
39
  from inspect_ai._view.notify import view_notify_eval
44
40
  from inspect_ai.dataset import Dataset, Sample
45
41
  from inspect_ai.log import (
@@ -90,6 +86,8 @@ from inspect_ai.solver._fork import set_task_generate
90
86
  from inspect_ai.solver._solver import Solver
91
87
  from inspect_ai.solver._task_state import sample_state, set_sample_state, state_jsonable
92
88
  from inspect_ai.util._limit import LimitExceededError
89
+ from inspect_ai.util._limit import time_limit as create_time_limit
90
+ from inspect_ai.util._limit import working_limit as create_working_limit
93
91
  from inspect_ai.util._sandbox.context import sandbox_connections
94
92
  from inspect_ai.util._sandbox.environment import SandboxEnvironmentSpec
95
93
  from inspect_ai.util._span import span
@@ -635,10 +633,6 @@ async def task_run_sample(
635
633
  )
636
634
 
637
635
  async with sandboxenv_cm:
638
- timeout_cm: (
639
- contextlib._GeneratorContextManager[anyio.CancelScope]
640
- | contextlib.nullcontext[None]
641
- ) = contextlib.nullcontext()
642
636
  try:
643
637
  # update active sample wth sandboxes now that we are initialised
644
638
  # (ensure that we still exit init context in presence of sandbox error)
@@ -647,19 +641,17 @@ async def task_run_sample(
647
641
  finally:
648
642
  await init_span.__aexit__(None, None, None)
649
643
 
650
- # initialise timeout context manager
651
- timeout_cm = (
652
- anyio.fail_after(time_limit)
653
- if time_limit is not None
654
- else contextlib.nullcontext()
655
- )
656
-
657
644
  # record start time
658
645
  start_time = time.monotonic()
659
- init_sample_working_limit(start_time, working_limit)
660
-
661
- # run sample w/ optional timeout
662
- with timeout_cm, state._token_limit, state._message_limit:
646
+ init_sample_working_time(start_time)
647
+
648
+ # run sample w/ optional limits
649
+ with (
650
+ state._token_limit,
651
+ state._message_limit,
652
+ create_time_limit(time_limit),
653
+ create_working_limit(working_limit),
654
+ ):
663
655
  # mark started
664
656
  active.started = datetime.now().timestamp()
665
657
 
@@ -675,24 +667,15 @@ async def task_run_sample(
675
667
  )
676
668
 
677
669
  # set progress for plan then run it
678
- state = await plan(state, generate)
679
-
680
- # disable sample working limit after execution
681
- end_sample_working_limit()
670
+ async with span("solvers"):
671
+ state = await plan(state, generate)
682
672
 
683
673
  except TimeoutError:
684
- if time_limit is not None:
685
- transcript()._event(
686
- SampleLimitEvent(
687
- type="time",
688
- message=f"Sample completed: exceeded time limit ({time_limit:,} seconds)",
689
- limit=time_limit,
690
- )
691
- )
692
- else:
693
- py_logger.warning(
694
- "Unexpected timeout error reached top of sample stack. Are you handling TimeoutError when applying timeouts?"
695
- )
674
+ # Scoped time limits manifest themselves as LimitExceededError, not
675
+ # TimeoutError.
676
+ py_logger.warning(
677
+ "Unexpected timeout error reached top of sample stack. Are you handling TimeoutError when applying timeouts?"
678
+ )
696
679
 
697
680
  # capture most recent state for scoring
698
681
  state = sample_state() or state
@@ -737,54 +720,59 @@ async def task_run_sample(
737
720
  # the cause of the timeout is a hung container and scoring requires
738
721
  # interacting with the container). as a middle ground we use half
739
722
  # of the original timeout value for scoring.
740
- if time_limit is not None:
741
- timeout_cm = anyio.fail_after(time_limit / 2)
723
+ scoring_time_limit = time_limit / 2 if time_limit else None
742
724
 
743
725
  set_sample_state(state)
744
726
 
745
727
  # scoring
746
728
  try:
747
729
  # timeout during scoring will result in an ordinary sample error
748
- with timeout_cm:
730
+ with create_time_limit(scoring_time_limit):
749
731
  if error is None:
750
- for scorer in scorers or []:
751
- scorer_name = unique_scorer_name(
752
- scorer, list(results.keys())
753
- )
754
- async with span(name=scorer_name, type="scorer"):
755
- score_result = (
756
- await scorer(state, Target(sample.target))
757
- if scorer
758
- else None
732
+ async with span(name="scorers"):
733
+ for scorer in scorers or []:
734
+ scorer_name = unique_scorer_name(
735
+ scorer, list(results.keys())
759
736
  )
760
- if score_result is not None:
761
- sample_score = SampleScore(
762
- score=score_result,
763
- sample_id=sample.id,
764
- sample_metadata=sample.metadata,
765
- scorer=registry_unqualified_name(scorer),
737
+ async with span(name=scorer_name, type="scorer"):
738
+ score_result = (
739
+ await scorer(state, Target(sample.target))
740
+ if scorer
741
+ else None
742
+ )
743
+ if score_result is not None:
744
+ sample_score = SampleScore(
745
+ score=score_result,
746
+ sample_id=sample.id,
747
+ sample_metadata=sample.metadata,
748
+ scorer=registry_unqualified_name(
749
+ scorer
750
+ ),
751
+ )
752
+ transcript()._event(
753
+ ScoreEvent(
754
+ score=score_result,
755
+ target=sample.target,
756
+ )
757
+ )
758
+ results[scorer_name] = sample_score
759
+
760
+ # add scores returned by solvers
761
+ if state.scores is not None:
762
+ for name, score in state.scores.items():
763
+ results[name] = SampleScore(
764
+ score=score,
765
+ sample_id=state.sample_id,
766
+ sample_metadata=state.metadata,
766
767
  )
767
768
  transcript()._event(
768
769
  ScoreEvent(
769
- score=score_result, target=sample.target
770
+ score=score, target=sample.target
770
771
  )
771
772
  )
772
- results[scorer_name] = sample_score
773
-
774
- # add scores returned by solvers
775
- if state.scores is not None:
776
- for name, score in state.scores.items():
777
- results[name] = SampleScore(
778
- score=score,
779
- sample_id=state.sample_id,
780
- sample_metadata=state.metadata,
781
- )
782
- transcript()._event(
783
- ScoreEvent(score=score, target=sample.target)
784
- )
785
773
 
786
- # propagate results into scores
787
- state.scores = {k: v.score for k, v in results.items()}
774
+ # propagate results into scores
775
+ state.scores = {k: v.score for k, v in results.items()}
788
776
 
789
777
  except anyio.get_cancelled_exc_class():
790
778
  if active.interrupt_action:
@@ -798,17 +786,7 @@ async def task_run_sample(
798
786
  raise
799
787
 
800
788
  except BaseException as ex:
801
- # note timeout
802
- if isinstance(ex, TimeoutError):
803
- transcript()._event(
804
- SampleLimitEvent(
805
- type="time",
806
- message=f"Unable to score sample due to exceeded time limit ({time_limit:,} seconds)",
807
- limit=time_limit,
808
- )
809
- )
810
-
811
- # handle error (this will throw if we've exceeded the limit)
789
+ # handle error
812
790
  error, raise_error = handle_error(ex)
813
791
 
814
792
  except Exception as ex:
@@ -64,7 +64,7 @@ class Task:
64
64
  time_limit: int | None = None,
65
65
  working_limit: int | None = None,
66
66
  name: str | None = None,
67
- version: int = 0,
67
+ version: int | str = 0,
68
68
  metadata: dict[str, Any] | None = None,
69
69
  **kwargs: Unpack[TaskDeprecatedArgs],
70
70
  ) -> None:
@@ -136,7 +136,7 @@ def current_async_backend() -> Literal["asyncio", "trio"] | None:
136
136
 
137
137
 
138
138
  def configured_async_backend() -> Literal["asyncio", "trio"]:
139
- backend = os.environ.get("INSPECT_ASYNC_BACKEND", "asyncio").lower()
139
+ backend = os.environ.get("INSPECT_ASYNC_BACKEND", "asyncio").lower() or "asyncio"
140
140
  return _validate_backend(backend)
141
141
 
142
142
 
@@ -1,9 +1,14 @@
1
1
  from typing import Literal, Union
2
2
 
3
- from pydantic import BaseModel, Field
3
+ from pydantic import BaseModel, Field, JsonValue
4
4
 
5
5
 
6
- class ContentText(BaseModel):
6
+ class ContentBase(BaseModel):
7
+ internal: JsonValue | None = Field(default=None)
8
+ """Model provider specific payload - typically used to aid transformation back to model types."""
9
+
10
+
11
+ class ContentText(ContentBase):
7
12
  """Text content."""
8
13
 
9
14
  type: Literal["text"] = Field(default="text")
@@ -16,7 +21,7 @@ class ContentText(BaseModel):
16
21
  """Was this a refusal message?"""
17
22
 
18
23
 
19
- class ContentReasoning(BaseModel):
24
+ class ContentReasoning(ContentBase):
20
25
  """Reasoning content.
21
26
 
22
27
  See the specification for [thinking blocks](https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#understanding-thinking-blocks) for Claude models.
@@ -35,7 +40,7 @@ class ContentReasoning(BaseModel):
35
40
  """Indicates that the explicit content of this reasoning block has been redacted."""
36
41
 
37
42
 
38
- class ContentImage(BaseModel):
43
+ class ContentImage(ContentBase):
39
44
  """Image content."""
40
45
 
41
46
  type: Literal["image"] = Field(default="image")
@@ -51,7 +56,7 @@ class ContentImage(BaseModel):
51
56
  """
52
57
 
53
58
 
54
- class ContentAudio(BaseModel):
59
+ class ContentAudio(ContentBase):
55
60
  """Audio content."""
56
61
 
57
62
  type: Literal["audio"] = Field(default="audio")
@@ -64,7 +69,7 @@ class ContentAudio(BaseModel):
64
69
  """Format of audio data ('mp3' or 'wav')"""
65
70
 
66
71
 
67
- class ContentVideo(BaseModel):
72
+ class ContentVideo(ContentBase):
68
73
  """Video content."""
69
74
 
70
75
  type: Literal["video"] = Field(default="video")
@@ -1,6 +1,6 @@
1
1
  import anyio
2
2
 
3
- from .working import check_sample_working_limit
3
+ from inspect_ai.util._limit import check_working_limit
4
4
 
5
5
 
6
6
  def check_sample_interrupt() -> None:
@@ -12,4 +12,4 @@ def check_sample_interrupt() -> None:
12
12
  raise anyio.get_cancelled_exc_class()
13
13
 
14
14
  # check for working_limit
15
- check_sample_working_limit()
15
+ check_working_limit()
inspect_ai/_util/text.py CHANGED
@@ -1,12 +1,19 @@
1
1
  import random
2
2
  import re
3
3
  import string
4
+ import textwrap
4
5
  from logging import getLogger
5
6
  from typing import List, NamedTuple
6
7
 
7
8
  logger = getLogger(__name__)
8
9
 
9
10
 
11
+ def truncate_text(text: str, max_length: int) -> str:
12
+ if len(text) <= max_length:
13
+ return text
14
+ return textwrap.shorten(text, width=max_length, placeholder="...")
15
+
16
+
10
17
  def strip_punctuation(s: str) -> str:
11
18
  return s.strip(string.whitespace + string.punctuation)
12
19