inspect-ai 0.3.92__py3-none-any.whl → 0.3.94__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (149) hide show
  1. inspect_ai/_cli/eval.py +27 -0
  2. inspect_ai/_display/textual/widgets/samples.py +3 -3
  3. inspect_ai/_display/textual/widgets/transcript.py +3 -29
  4. inspect_ai/_eval/eval.py +19 -2
  5. inspect_ai/_eval/evalset.py +4 -1
  6. inspect_ai/_eval/run.py +41 -0
  7. inspect_ai/_eval/task/generate.py +38 -44
  8. inspect_ai/_eval/task/log.py +26 -28
  9. inspect_ai/_eval/task/run.py +23 -27
  10. inspect_ai/_util/answer.py +26 -0
  11. inspect_ai/_util/constants.py +0 -1
  12. inspect_ai/_util/local_server.py +398 -0
  13. inspect_ai/_util/working.py +10 -4
  14. inspect_ai/_view/www/dist/assets/index.css +173 -159
  15. inspect_ai/_view/www/dist/assets/index.js +1417 -1142
  16. inspect_ai/_view/www/log-schema.json +379 -3
  17. inspect_ai/_view/www/package.json +1 -1
  18. inspect_ai/_view/www/src/@types/log.d.ts +93 -14
  19. inspect_ai/_view/www/src/app/content/MetaDataGrid.tsx +2 -2
  20. inspect_ai/_view/www/src/app/content/MetaDataView.module.css +1 -1
  21. inspect_ai/_view/www/src/app/content/MetadataGrid.module.css +1 -1
  22. inspect_ai/_view/www/src/app/content/RenderedContent.tsx +1 -1
  23. inspect_ai/_view/www/src/app/log-view/LogView.tsx +11 -0
  24. inspect_ai/_view/www/src/app/log-view/tabs/InfoTab.tsx +2 -9
  25. inspect_ai/_view/www/src/app/log-view/tabs/ModelsTab.tsx +51 -0
  26. inspect_ai/_view/www/src/app/log-view/tabs/TaskTab.module.css +6 -0
  27. inspect_ai/_view/www/src/app/log-view/tabs/TaskTab.tsx +143 -0
  28. inspect_ai/_view/www/src/app/plan/ModelCard.tsx +1 -2
  29. inspect_ai/_view/www/src/app/plan/PlanCard.tsx +29 -7
  30. inspect_ai/_view/www/src/app/plan/PlanDetailView.module.css +1 -1
  31. inspect_ai/_view/www/src/app/plan/PlanDetailView.tsx +1 -198
  32. inspect_ai/_view/www/src/app/samples/descriptor/score/NumericScoreDescriptor.tsx +2 -1
  33. inspect_ai/_view/www/src/app/samples/transcript/SandboxEventView.module.css +2 -1
  34. inspect_ai/_view/www/src/app/samples/transcript/SpanEventView.tsx +174 -0
  35. inspect_ai/_view/www/src/app/samples/transcript/ToolEventView.tsx +8 -8
  36. inspect_ai/_view/www/src/app/samples/transcript/TranscriptView.tsx +12 -2
  37. inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualListComponent.module.css +1 -1
  38. inspect_ai/_view/www/src/app/samples/transcript/event/EventPanel.tsx +0 -3
  39. inspect_ai/_view/www/src/app/samples/transcript/transform/fixups.ts +87 -25
  40. inspect_ai/_view/www/src/app/samples/transcript/transform/treeify.ts +229 -17
  41. inspect_ai/_view/www/src/app/samples/transcript/transform/utils.ts +11 -0
  42. inspect_ai/_view/www/src/app/samples/transcript/types.ts +5 -1
  43. inspect_ai/_view/www/src/app/usage/ModelUsagePanel.tsx +3 -2
  44. inspect_ai/_view/www/src/app/usage/TokenTable.module.css +4 -1
  45. inspect_ai/_view/www/src/app/usage/TokenTable.tsx +2 -2
  46. inspect_ai/_view/www/src/app/usage/UsageCard.module.css +8 -3
  47. inspect_ai/_view/www/src/app/usage/UsageCard.tsx +1 -35
  48. inspect_ai/_view/www/src/components/Card.css +0 -1
  49. inspect_ai/_view/www/src/constants.ts +2 -0
  50. inspect_ai/_view/www/src/utils/numeric.ts +17 -0
  51. inspect_ai/agent/_agent.py +3 -3
  52. inspect_ai/agent/_as_solver.py +22 -12
  53. inspect_ai/agent/_as_tool.py +20 -6
  54. inspect_ai/agent/_handoff.py +12 -1
  55. inspect_ai/agent/_react.py +4 -3
  56. inspect_ai/agent/_run.py +16 -3
  57. inspect_ai/agent/_types.py +9 -0
  58. inspect_ai/dataset/_dataset.py +6 -3
  59. inspect_ai/log/__init__.py +14 -0
  60. inspect_ai/log/_convert.py +4 -9
  61. inspect_ai/log/_file.py +56 -0
  62. inspect_ai/log/_log.py +99 -0
  63. inspect_ai/log/_recorders/__init__.py +2 -0
  64. inspect_ai/log/_recorders/buffer/database.py +12 -11
  65. inspect_ai/log/_recorders/buffer/filestore.py +2 -2
  66. inspect_ai/log/_recorders/buffer/types.py +2 -2
  67. inspect_ai/log/_recorders/eval.py +20 -65
  68. inspect_ai/log/_recorders/file.py +28 -6
  69. inspect_ai/log/_recorders/recorder.py +7 -0
  70. inspect_ai/log/_recorders/types.py +1 -23
  71. inspect_ai/log/_samples.py +14 -25
  72. inspect_ai/log/_transcript.py +84 -36
  73. inspect_ai/log/_tree.py +118 -0
  74. inspect_ai/log/_util.py +52 -0
  75. inspect_ai/model/__init__.py +5 -1
  76. inspect_ai/model/_call_tools.py +72 -44
  77. inspect_ai/model/_generate_config.py +14 -8
  78. inspect_ai/model/_model.py +66 -88
  79. inspect_ai/model/_model_output.py +25 -0
  80. inspect_ai/model/_openai.py +2 -0
  81. inspect_ai/model/_providers/anthropic.py +13 -23
  82. inspect_ai/model/_providers/hf.py +27 -1
  83. inspect_ai/model/_providers/openai_o1.py +8 -2
  84. inspect_ai/model/_providers/providers.py +18 -4
  85. inspect_ai/model/_providers/sglang.py +247 -0
  86. inspect_ai/model/_providers/vllm.py +211 -400
  87. inspect_ai/scorer/_choice.py +1 -2
  88. inspect_ai/solver/__init__.py +7 -2
  89. inspect_ai/solver/_basic_agent.py +3 -10
  90. inspect_ai/solver/_chain.py +1 -1
  91. inspect_ai/solver/_fork.py +1 -1
  92. inspect_ai/solver/_multiple_choice.py +5 -22
  93. inspect_ai/solver/_plan.py +2 -2
  94. inspect_ai/solver/_task_state.py +26 -88
  95. inspect_ai/solver/_transcript.py +6 -7
  96. inspect_ai/tool/_json_rpc_helpers.py +45 -17
  97. inspect_ai/tool/_mcp/_mcp.py +8 -5
  98. inspect_ai/tool/_mcp/_sandbox.py +8 -2
  99. inspect_ai/tool/_mcp/server.py +3 -1
  100. inspect_ai/tool/_tool_call.py +4 -1
  101. inspect_ai/tool/_tool_support_helpers.py +51 -12
  102. inspect_ai/tool/_tools/_bash_session.py +190 -68
  103. inspect_ai/tool/_tools/_computer/_computer.py +25 -1
  104. inspect_ai/tool/_tools/_execute.py +4 -1
  105. inspect_ai/tool/_tools/_text_editor.py +4 -3
  106. inspect_ai/tool/_tools/_web_browser/_web_browser.py +10 -3
  107. inspect_ai/util/__init__.py +16 -0
  108. inspect_ai/util/_anyio.py +11 -0
  109. inspect_ai/util/_collect.py +50 -0
  110. inspect_ai/util/_limit.py +393 -0
  111. inspect_ai/util/_limited_conversation.py +57 -0
  112. inspect_ai/util/_span.py +58 -0
  113. inspect_ai/util/_subtask.py +27 -42
  114. {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.94.dist-info}/METADATA +1 -1
  115. {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.94.dist-info}/RECORD +120 -134
  116. {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.94.dist-info}/WHEEL +1 -1
  117. inspect_ai/_display/core/group.py +0 -79
  118. inspect_ai/solver/_limit.py +0 -39
  119. inspect_ai/tool/_tools/_computer/_resources/Dockerfile +0 -102
  120. inspect_ai/tool/_tools/_computer/_resources/README.md +0 -30
  121. inspect_ai/tool/_tools/_computer/_resources/entrypoint/entrypoint.sh +0 -18
  122. inspect_ai/tool/_tools/_computer/_resources/entrypoint/novnc_startup.sh +0 -20
  123. inspect_ai/tool/_tools/_computer/_resources/entrypoint/x11vnc_startup.sh +0 -48
  124. inspect_ai/tool/_tools/_computer/_resources/entrypoint/xfce_startup.sh +0 -13
  125. inspect_ai/tool/_tools/_computer/_resources/entrypoint/xvfb_startup.sh +0 -48
  126. inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/Code/User/globalStorage/state.vscdb +0 -0
  127. inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/Code/User/settings.json +0 -9
  128. inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/xfce4/xfconf/xfce-perchannel-xml/xfce4-panel.xml +0 -61
  129. inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/xfce4/xfconf/xfce-perchannel-xml/xfce4-screensaver.xml +0 -10
  130. inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/xfce4/xfconf/xfce-perchannel-xml/xfwm4.xml +0 -91
  131. inspect_ai/tool/_tools/_computer/_resources/image_home_dir/Desktop/Firefox Web Browser.desktop +0 -10
  132. inspect_ai/tool/_tools/_computer/_resources/image_home_dir/Desktop/Terminal.desktop +0 -10
  133. inspect_ai/tool/_tools/_computer/_resources/image_home_dir/Desktop/Visual Studio Code.desktop +0 -10
  134. inspect_ai/tool/_tools/_computer/_resources/tool/.pylintrc +0 -8
  135. inspect_ai/tool/_tools/_computer/_resources/tool/.vscode/settings.json +0 -12
  136. inspect_ai/tool/_tools/_computer/_resources/tool/_args.py +0 -78
  137. inspect_ai/tool/_tools/_computer/_resources/tool/_constants.py +0 -22
  138. inspect_ai/tool/_tools/_computer/_resources/tool/_logger.py +0 -22
  139. inspect_ai/tool/_tools/_computer/_resources/tool/_run.py +0 -42
  140. inspect_ai/tool/_tools/_computer/_resources/tool/_tool_result.py +0 -33
  141. inspect_ai/tool/_tools/_computer/_resources/tool/_x11_client.py +0 -341
  142. inspect_ai/tool/_tools/_computer/_resources/tool/computer_tool.py +0 -141
  143. inspect_ai/tool/_tools/_computer/_resources/tool/pyproject.toml +0 -65
  144. inspect_ai/tool/_tools/_computer/_resources/tool/requirements.txt +0 -0
  145. inspect_ai/tool/_tools/_computer/test_args.py +0 -151
  146. /inspect_ai/{tool/_tools/_computer/_resources/tool/__init__.py → _view/www/src/app/log-view/tabs/ModelsTab.module.css} +0 -0
  147. {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.94.dist-info}/entry_points.txt +0 -0
  148. {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.94.dist-info}/licenses/LICENSE +0 -0
  149. {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.94.dist-info}/top_level.txt +0 -0
@@ -13,8 +13,8 @@ from inspect_ai.scorer._score import score
13
13
  from inspect_ai.solver._chain import chain
14
14
  from inspect_ai.tool._tool import Tool, ToolResult, tool
15
15
  from inspect_ai.tool._tool_with import tool_with
16
+ from inspect_ai.util._limit import token_limit as create_token_limit
16
17
 
17
- from ._limit import SampleLimitExceededError
18
18
  from ._prompt import system_message
19
19
  from ._solver import Generate, Solver, solver
20
20
  from ._task_state import TaskState
@@ -172,14 +172,11 @@ def basic_agent(
172
172
  # (if there is no message_limit then default to 50)
173
173
  state.message_limit = message_limit or state.message_limit or 50
174
174
 
175
- # resolve token limit
176
- state.token_limit = token_limit or state.token_limit
177
-
178
175
  # track attempts
179
176
  attempts = 0
180
177
 
181
- try:
182
- # main loop (state.completed checks message_limit and token_limit)
178
+ with create_token_limit(token_limit):
179
+ # main loop
183
180
  while not state.completed:
184
181
  # generate output and append assistant message
185
182
  state.output = await get_model().generate(
@@ -247,10 +244,6 @@ def basic_agent(
247
244
  else:
248
245
  state.messages.append(ChatMessageUser(content=continue_message))
249
246
 
250
- # propagate current state along with sample limit exceeded
251
- except SampleLimitExceededError as ex:
252
- raise ex.with_state(state)
253
-
254
247
  return state
255
248
 
256
249
  return solve
@@ -82,7 +82,7 @@ class Chain(Sequence[Solver], Solver):
82
82
  from ._transcript import solver_transcript
83
83
 
84
84
  for slv in self._solvers:
85
- with solver_transcript(slv, state) as st:
85
+ async with solver_transcript(slv, state) as st:
86
86
  state = await slv(state, generate)
87
87
  st.complete(state)
88
88
  if state.completed:
@@ -73,7 +73,7 @@ async def solver_subtask(state: TaskState, solver: Solver) -> TaskState:
73
73
  @subtask(name=name, store=state.store, type="fork", input=input) # type: ignore
74
74
  async def solve() -> TaskState:
75
75
  if not isinstance(solver, Chain):
76
- with solver_transcript(solver, state) as st:
76
+ async with solver_transcript(solver, state) as st:
77
77
  new_state = await solver(state, generate)
78
78
  st.complete(new_state)
79
79
  return new_state
@@ -6,6 +6,7 @@ from typing import Match, TypedDict
6
6
 
7
7
  from typing_extensions import Unpack
8
8
 
9
+ from inspect_ai._util.answer import answer_character, answer_index
9
10
  from inspect_ai._util.logger import warn_once
10
11
  from inspect_ai.util import resource
11
12
 
@@ -64,31 +65,13 @@ def answer_options(choices: Choices) -> str:
64
65
  indexes = list(range(len(choices)))
65
66
 
66
67
  return "\n".join(
67
- [f"{chr(65 + i)}) {choices[j].value}" for i, j in enumerate(indexes)]
68
+ [f"{answer_character(i)}) {choices[j].value}" for i, j in enumerate(indexes)]
68
69
  )
69
70
 
70
71
 
71
- def answer_character(index: int) -> str:
72
- r"""
73
- Helper to go from array index to char, for example:
74
-
75
- 0 -> 'A', 1 -> 'B', etc
76
- """
77
- return chr(ord("A") + index)
78
-
79
-
80
- def answer_index(char: str) -> int:
81
- r"""
82
- Helper to go from char to array index, for example:
83
-
84
- 'A' -> 0, 'B' -> 1, etc
85
- """
86
- return ord(char.upper()) - ord("A")
87
-
88
-
89
72
  def prompt(question: str, choices: Choices, template: str) -> str:
90
73
  choices_text = answer_options(choices)
91
- letters = ",".join(chr(65 + i) for i in range(len(choices)))
74
+ letters = ",".join(answer_character(i) for i in range(len(choices)))
92
75
 
93
76
  return template.format(
94
77
  choices=choices_text,
@@ -112,7 +95,7 @@ def parse_answers(state: TaskState) -> Match[str] | None:
112
95
  # In this case, we're looking for a single line which contains the expected
113
96
  # ANSWER: B,C string with only whitespace after it
114
97
  match = re.search(
115
- r"(?i)^ANSWER\s*:\s*([A-Za-z ,]+)\s*(?:$|\n)",
98
+ r"(?i)^ANSWER\s*:\s*([A-Za-z\d ,]+)\s*(?:$|\n)",
116
99
  state.output.completion,
117
100
  flags=re.MULTILINE,
118
101
  )
@@ -121,7 +104,7 @@ def parse_answers(state: TaskState) -> Match[str] | None:
121
104
  # version for backward compatibility
122
105
  if match is None:
123
106
  return re.search(
124
- r"(?i)ANSWER\s*:\s*([A-Za-z ,]+)(?:[^\w]|\n|$)", state.output.completion
107
+ r"(?i)ANSWER\s*:\s*([A-Za-z\d ,]+)(?:[^\w]|\n|$)", state.output.completion
125
108
  )
126
109
  else:
127
110
  return match
@@ -102,7 +102,7 @@ class Plan(Solver):
102
102
  # execute steps
103
103
  for index, solver in enumerate(self.steps):
104
104
  # run solver
105
- with solver_transcript(solver, state) as st:
105
+ async with solver_transcript(solver, state) as st:
106
106
  state = await solver(state, generate)
107
107
  st.complete(state)
108
108
 
@@ -113,7 +113,7 @@ class Plan(Solver):
113
113
 
114
114
  # execute finish
115
115
  if self.finish:
116
- with solver_transcript(self.finish, state) as st:
116
+ async with solver_transcript(self.finish, state) as st:
117
117
  state = await self.finish(state, generate)
118
118
  st.complete(state)
119
119
  check_sample_interrupt()
@@ -2,9 +2,8 @@ from collections.abc import Sequence
2
2
  from contextvars import ContextVar
3
3
  from copy import deepcopy
4
4
  from dataclasses import dataclass
5
- from itertools import tee
6
5
  from random import Random
7
- from typing import Any, Iterable, SupportsIndex, Type, Union, cast, overload
6
+ from typing import Any, Type, Union, cast, overload
8
7
 
9
8
  from pydantic_core import to_jsonable_python
10
9
  from shortuuid import uuid
@@ -18,12 +17,18 @@ from inspect_ai.model import (
18
17
  ModelOutput,
19
18
  )
20
19
  from inspect_ai.model._call_tools import tools_info
21
- from inspect_ai.model._chat_message import ChatMessageBase
22
20
  from inspect_ai.model._model import sample_total_tokens
23
21
  from inspect_ai.scorer._metric import Score
24
22
  from inspect_ai.scorer._target import Target
25
23
  from inspect_ai.tool import Tool, ToolChoice
26
24
  from inspect_ai.tool._tool_def import ToolDef
25
+ from inspect_ai.util._limit import (
26
+ check_message_limit,
27
+ check_token_limit,
28
+ )
29
+ from inspect_ai.util._limit import message_limit as create_message_limit
30
+ from inspect_ai.util._limit import token_limit as create_token_limit
31
+ from inspect_ai.util._limited_conversation import ChatMessageList
27
32
  from inspect_ai.util._store import Store, store_jsonable
28
33
  from inspect_ai.util._store_model import SMT
29
34
 
@@ -159,11 +164,11 @@ class TaskState:
159
164
  self._input = input
160
165
  self._target = target
161
166
  self._metadata = metadata
162
- self._messages: list[ChatMessage] = ChatMessageList(messages, self)
167
+ self._messages: list[ChatMessage] = ChatMessageList(messages)
163
168
  self._tools: list[Tool] = []
164
169
  self._output = output if output else ModelOutput(model=str(model))
165
- self._message_limit = message_limit
166
- self._token_limit = token_limit
170
+ self._message_limit = create_message_limit(message_limit)
171
+ self._token_limit = create_token_limit(token_limit)
167
172
  self._completed = completed
168
173
  self._store = Store()
169
174
  self._uuid = uuid()
@@ -254,7 +259,7 @@ class TaskState:
254
259
 
255
260
  @messages.setter
256
261
  def messages(self, messages: list[ChatMessage]) -> None:
257
- self._messages = ChatMessageList(messages, self)
262
+ self._messages = ChatMessageList(messages)
258
263
 
259
264
  @property
260
265
  def output(self) -> ModelOutput:
@@ -302,12 +307,16 @@ class TaskState:
302
307
  @property
303
308
  def message_limit(self) -> int | None:
304
309
  """Limit on total messages allowed per conversation."""
305
- return self._message_limit
310
+ return self._message_limit.limit
306
311
 
307
312
  @message_limit.setter
308
313
  def message_limit(self, messages: int | None) -> None:
309
- """Set limit on total messages allowed per conversation."""
310
- self._message_limit = messages
314
+ """Set limit on total messages allowed per conversation.
315
+
316
+ Also checks whether the current message count exceeds the new limit.
317
+ """
318
+ self._message_limit.limit = messages
319
+ check_message_limit(len(self.messages), raise_for_equal=False)
311
320
 
312
321
  from inspect_ai.log._samples import set_active_sample_message_limit
313
322
 
@@ -316,12 +325,16 @@ class TaskState:
316
325
  @property
317
326
  def token_limit(self) -> int | None:
318
327
  """Limit on total tokens allowed per conversation."""
319
- return self._token_limit
328
+ return self._token_limit.limit
320
329
 
321
330
  @token_limit.setter
322
331
  def token_limit(self, tokens: int | None) -> None:
323
- """Set limit on total tokens allowed per conversation."""
324
- self._token_limit = tokens
332
+ """Set limit on total tokens allowed per conversation.
333
+
334
+ Also checks whether the current token usage exceeds the new limit.
335
+ """
336
+ self._token_limit.limit = tokens
337
+ check_token_limit()
325
338
 
326
339
  from inspect_ai.log._samples import set_active_sample_token_limit
327
340
 
@@ -340,24 +353,11 @@ class TaskState:
340
353
  """
341
354
  from inspect_ai.log._samples import set_active_sample_total_messages
342
355
 
343
- from ._limit import SampleLimitExceededError
344
-
345
356
  # update messages
346
357
  set_active_sample_total_messages(len(self.messages))
347
358
 
348
359
  if self._completed:
349
360
  return True
350
- elif self.message_limit and len(self.messages) >= self.message_limit:
351
- raise SampleLimitExceededError(
352
- "message",
353
- value=len(self.messages),
354
- limit=self.message_limit,
355
- state=self,
356
- )
357
- elif self.token_limit and self.token_usage >= self.token_limit:
358
- raise SampleLimitExceededError(
359
- "token", value=self.token_usage, limit=self.token_limit, state=self
360
- )
361
361
  else:
362
362
  check_sample_interrupt()
363
363
  return self._completed
@@ -445,65 +445,3 @@ def state_jsonable(state: TaskState | None = None) -> dict[str, Any]:
445
445
  def sample_jsonable(sample: Sample) -> dict[str, Any]:
446
446
  jsonable = to_jsonable_python(sample, exclude_none=True, fallback=lambda _x: None)
447
447
  return cast(dict[str, Any], deepcopy(jsonable))
448
-
449
-
450
- class ChatMessageList(list[ChatMessage]):
451
- def __init__(self, iterable: Iterable[ChatMessage], parent_state: TaskState):
452
- self.parent_state = parent_state
453
- items, length = self._iterable_length(iterable)
454
- self._check_size(length)
455
- super().__init__(items)
456
-
457
- def _check_size(self, additional_items: int = 1) -> None:
458
- from inspect_ai.log._samples import active_sample_message_limit
459
-
460
- from ._limit import SampleLimitExceededError
461
-
462
- messages_limit = active_sample_message_limit()
463
- if messages_limit is not None:
464
- messages = len(self) + additional_items
465
- if messages > messages_limit:
466
- raise SampleLimitExceededError(
467
- "message",
468
- value=messages,
469
- limit=messages_limit,
470
- message=None,
471
- state=self.parent_state,
472
- )
473
-
474
- def append(self, item: ChatMessage) -> None:
475
- self._check_size()
476
- super().append(item)
477
-
478
- def extend(self, items: Iterable[ChatMessage]) -> None:
479
- items, length = self._iterable_length(items)
480
- self._check_size(length)
481
- super().extend(items)
482
-
483
- def insert(self, index: SupportsIndex, item: ChatMessage) -> None:
484
- self._check_size()
485
- super().insert(index, item)
486
-
487
- @overload
488
- def __setitem__(self, index: SupportsIndex, item: ChatMessage) -> None: ...
489
-
490
- @overload
491
- def __setitem__(self, index: slice, item: Iterable[ChatMessage]) -> None: ...
492
-
493
- def __setitem__(
494
- self, index: SupportsIndex | slice, item: ChatMessage | Iterable[ChatMessage]
495
- ) -> None:
496
- if isinstance(index, slice) and not isinstance(item, ChatMessageBase):
497
- item, length = self._iterable_length(item)
498
- size_change = length - len(self[index])
499
- if size_change > 0:
500
- self._check_size(size_change)
501
-
502
- super().__setitem__(index, item) # type: ignore[assignment,index]
503
-
504
- def _iterable_length(
505
- self, items: Iterable[ChatMessage]
506
- ) -> tuple[Iterable[ChatMessage], int]:
507
- items, counter = tee(items)
508
- length = sum(1 for _ in counter)
509
- return items, length
@@ -1,8 +1,9 @@
1
1
  import contextlib
2
- from typing import Iterator
2
+ from typing import AsyncIterator
3
3
 
4
4
  from inspect_ai._util.json import json_changes
5
5
  from inspect_ai._util.registry import registry_log_name
6
+ from inspect_ai.util._span import span
6
7
 
7
8
  from ._solver import Solver
8
9
  from ._task_state import TaskState, state_jsonable
@@ -22,12 +23,10 @@ class SolverTranscript:
22
23
  transcript()._event(StateEvent(changes=changes))
23
24
 
24
25
 
25
- @contextlib.contextmanager
26
- def solver_transcript(
26
+ @contextlib.asynccontextmanager
27
+ async def solver_transcript(
27
28
  solver: Solver, state: TaskState, name: str | None = None
28
- ) -> Iterator[SolverTranscript]:
29
- from inspect_ai.log._transcript import transcript
30
-
29
+ ) -> AsyncIterator[SolverTranscript]:
31
30
  name = registry_log_name(name or solver)
32
- with transcript().step(name=name, type="solver"):
31
+ async with span(name=name, type="solver"):
33
32
  yield SolverTranscript(name, state)
@@ -4,7 +4,7 @@ from typing import Literal, Protocol, Type, TypeAlias, TypeVar
4
4
 
5
5
  from pydantic import BaseModel, RootModel
6
6
 
7
- from inspect_ai.tool._tool import ToolError
7
+ from inspect_ai.tool._tool import ToolError, ToolParsingError
8
8
 
9
9
 
10
10
  class JSONRPCResponseBase(BaseModel):
@@ -70,6 +70,7 @@ async def exec_scalar_request(
70
70
  params: JSONRPCParamsType,
71
71
  result_type: Type[ScalarT],
72
72
  transport: JSONRPCTransport,
73
+ server_error_mapper: JSONRPCServerErrorMapper,
73
74
  ) -> ScalarT:
74
75
  """
75
76
  Execute a JSON-RPC command expecting a scalar result.
@@ -79,6 +80,7 @@ async def exec_scalar_request(
79
80
  params (JSONRPCParamsType): The parameters for the JSON-RPC method.
80
81
  result_type (Type[ScalarT]): The scalar type (str, int, float, bool, None) to validate the result against.
81
82
  transport (JSONRPCTransport): The transport callable to use for the RPC communication.
83
+ server_error_mapper (JSONRPCServerErrorMapper): A callable to map server specific JSON-RPC errors to exceptions.
82
84
 
83
85
  Returns:
84
86
  ScalarT: The scalar result of the JSON-RPC call.
@@ -88,7 +90,12 @@ async def exec_scalar_request(
88
90
  ToolParsingError: If the JSON-RPC response contains a specific error code indicating a parsing error.
89
91
  ValueError: If the result is not of the expected scalar type.
90
92
  """
91
- rpc_result = await _exec_request(method=method, params=params, transport=transport)
93
+ rpc_result = await _exec_request(
94
+ method=method,
95
+ params=params,
96
+ transport=transport,
97
+ server_error_mapper=server_error_mapper,
98
+ )
92
99
  if (result_type is type(None) and rpc_result is not None) or not isinstance(
93
100
  rpc_result, result_type
94
101
  ):
@@ -101,6 +108,7 @@ async def exec_model_request(
101
108
  params: JSONRPCParamsType,
102
109
  result_type: Type[BaseModelT],
103
110
  transport: JSONRPCTransport,
111
+ server_error_mapper: JSONRPCServerErrorMapper | None = None,
104
112
  ) -> BaseModelT:
105
113
  """
106
114
  Execute a JSON-RPC command to a sandbox environment expecting a model result.
@@ -110,6 +118,7 @@ async def exec_model_request(
110
118
  params (JSONRPCParamsType): The parameters for the JSON-RPC method.
111
119
  result_type (Type[BaseModelT]): The Pydantic model class to validate and parse the result.
112
120
  transport (JSONRPCTransport): The transport callable to use for the RPC communication.
121
+ server_error_mapper (JSONRPCServerErrorMapper): A callable to map server specific JSON-RPC errors to exceptions.
113
122
 
114
123
  Returns:
115
124
  BaseModelT: The parsed and validated result of the JSON-RPC call.
@@ -119,7 +128,12 @@ async def exec_model_request(
119
128
  ToolParsingError: If the JSON-RPC response contains a specific error code indicating a parsing error.
120
129
  ValueError: If the result cannot be validated against the provided model class.
121
130
  """
122
- rpc_result = await _exec_request(method=method, params=params, transport=transport)
131
+ rpc_result = await _exec_request(
132
+ method=method,
133
+ params=params,
134
+ transport=transport,
135
+ server_error_mapper=server_error_mapper,
136
+ )
123
137
  return result_type.model_validate(rpc_result, strict=True)
124
138
 
125
139
 
@@ -161,6 +175,7 @@ async def _exec_request(
161
175
  method: str,
162
176
  params: JSONRPCParamsType,
163
177
  transport: JSONRPCTransport,
178
+ server_error_mapper: JSONRPCServerErrorMapper | None = None,
164
179
  ) -> object:
165
180
  """Execute a request using the provided transport mechanism."""
166
181
  return parse_json_rpc_response(
@@ -171,6 +186,7 @@ async def _exec_request(
171
186
  ),
172
187
  method,
173
188
  params,
189
+ server_error_mapper,
174
190
  )
175
191
 
176
192
 
@@ -178,15 +194,16 @@ def parse_json_rpc_response(
178
194
  response_str: str,
179
195
  method: str,
180
196
  params: JSONRPCParamsType,
197
+ server_error_mapper: JSONRPCServerErrorMapper | None = None,
181
198
  ) -> object:
182
199
  """Validates the JSON RPC response and returns the result or raises a proper Inspect error."""
183
200
  match JSONRPCResponse.model_validate_json(response_str).root:
184
201
  case JSONRPCSuccessResponse(result=rpc_result):
185
202
  return rpc_result
186
- case JSONRPCErrorResponse(
187
- error=JSONRPCError(code=code, message=message, data=_)
188
- ):
189
- raise exception_for_rpc_response_error(code, message, method, params)
203
+ case JSONRPCErrorResponse(error=JSONRPCError(code=code, message=message)):
204
+ raise exception_for_rpc_response_error(
205
+ code, message, method, params, server_error_mapper
206
+ )
190
207
  case _:
191
208
  raise ValueError(
192
209
  f"Unexpected JSON RPC response to request {_rpc_call_description(method, params)}: {response_str}"
@@ -220,16 +237,17 @@ def exception_for_rpc_response_error(
220
237
  if server_error_mapper
221
238
  else ToolError(message)
222
239
  )
240
+ elif code == -32602: # (Invalid params)
241
+ # Even though the Inspect side does validation, it can't possibly be
242
+ # complete - especially for tools that have dynamic action dependant
243
+ # rules for optional/required params.
244
+ return ToolParsingError(message)
223
245
  elif code == -32603:
224
246
  return ToolError(message)
225
247
  else:
226
248
  # -32600 (Invalid Request)
227
249
  # If we sent a bogus request, it's 100% a code bug.
228
250
  # -32601 (Method not found)
229
- # -32602 (Invalid params)
230
- # These shouldn't be possible since Inspect did validation prior to
231
- # making the tool call. Because of that, these errors should not make
232
- # it back to the model, so choose RuntimeError.
233
251
  # -32700 (Parse error)
234
252
  # shouldn't be seen in this flow since we're processing responses, and
235
253
  # this is a request oriented error.
@@ -276,10 +294,20 @@ def create_json_rpc_request(
276
294
  is_notification: bool,
277
295
  ) -> str:
278
296
  return json.dumps(
279
- {
280
- "jsonrpc": "2.0",
281
- "method": method,
282
- **({"params": params} if params else {}),
283
- **({"id": next(id_generator)} if not is_notification else {}),
284
- }
297
+ remove_none_values(
298
+ {
299
+ "jsonrpc": "2.0",
300
+ "method": method,
301
+ **({"params": params} if params else {}),
302
+ **({"id": next(id_generator)} if not is_notification else {}),
303
+ }
304
+ )
285
305
  )
306
+
307
+
308
+ def remove_none_values(obj: object) -> object:
309
+ if isinstance(obj, dict):
310
+ return {k: remove_none_values(v) for k, v in obj.items() if v is not None}
311
+ elif isinstance(obj, list):
312
+ return [remove_none_values(item) for item in obj if item is not None]
313
+ return obj
@@ -61,16 +61,17 @@ class MCPServerImpl(MCPServer):
61
61
  ) -> list[Tool]:
62
62
  return await self._task_session()._list_tools(tools)
63
63
 
64
- # create a separate MCPServer session per async task
65
- _task_sessions: dict[int, "MCPServerSession"] = {}
64
+ # create a separate MCPServer session per async task / server name
65
+ _task_sessions: dict[str, "MCPServerSession"] = {}
66
66
 
67
67
  def _task_session(self) -> "MCPServerSession":
68
68
  task_id = anyio.get_current_task().id
69
- if task_id not in self._task_sessions:
70
- MCPServerImpl._task_sessions[task_id] = MCPServerSession(
69
+ session_key = f"{task_id}_{self._name}"
70
+ if session_key not in self._task_sessions:
71
+ MCPServerImpl._task_sessions[session_key] = MCPServerSession(
71
72
  self._client, name=self._name, events=self._events
72
73
  )
73
- return MCPServerImpl._task_sessions[task_id]
74
+ return MCPServerImpl._task_sessions[session_key]
74
75
 
75
76
 
76
77
  class MCPServerSession(MCPServer):
@@ -259,6 +260,7 @@ def create_server_sandbox(
259
260
  cwd: str | Path | None = None,
260
261
  env: dict[str, str] | None = None,
261
262
  sandbox: str | None = None,
263
+ timeout: int | None = None,
262
264
  ) -> MCPServer:
263
265
  # TODO: Confirm the lifetime concepts. By the time a request makes it to the
264
266
  # sandbox, it's going to need both a session id and a server "name".
@@ -272,6 +274,7 @@ def create_server_sandbox(
272
274
  env=env,
273
275
  ),
274
276
  sandbox_name=sandbox,
277
+ timeout=timeout,
275
278
  ),
276
279
  name=name,
277
280
  events=False,
@@ -11,7 +11,7 @@ from inspect_ai.tool._tool_support_helpers import (
11
11
  exec_model_request,
12
12
  exec_notification,
13
13
  exec_scalar_request,
14
- tool_container_sandbox,
14
+ tool_support_sandbox,
15
15
  )
16
16
 
17
17
  from ._context import MCPServerContext
@@ -28,8 +28,10 @@ async def sandbox_client( # type: ignore
28
28
  *,
29
29
  sandbox_name: str | None = None,
30
30
  errlog: TextIO = sys.stderr,
31
+ timeout: int | None = None, # default 180 seconds
31
32
  ) -> MCPServerContext: # type: ignore
32
- sandbox_environment = await tool_container_sandbox(
33
+ timeout = timeout or 180
34
+ (sandbox_environment, _) = await tool_support_sandbox(
33
35
  "mcp support", sandbox_name=sandbox_name
34
36
  )
35
37
 
@@ -49,6 +51,7 @@ async def sandbox_client( # type: ignore
49
51
  method="mcp_launch_server",
50
52
  params={"server_params": server.model_dump()},
51
53
  result_type=int,
54
+ timeout=timeout,
52
55
  )
53
56
 
54
57
  async def stdout_reader() -> None:
@@ -72,6 +75,7 @@ async def sandbox_client( # type: ignore
72
75
  "request": root.model_dump(),
73
76
  },
74
77
  result_type=JSONRPCMessage,
78
+ timeout=timeout,
75
79
  )
76
80
  )
77
81
  elif isinstance(root, JSONRPCNotification):
@@ -82,6 +86,7 @@ async def sandbox_client( # type: ignore
82
86
  "session_id": session_id,
83
87
  "notification": root.model_dump(),
84
88
  },
89
+ timeout=timeout,
85
90
  )
86
91
  else:
87
92
  assert False, f"Unexpected message type {message=}"
@@ -101,4 +106,5 @@ async def sandbox_client( # type: ignore
101
106
  method="mcp_kill_server",
102
107
  params={"session_id": session_id},
103
108
  result_type=type(None),
109
+ timeout=timeout,
104
110
  )
@@ -73,6 +73,7 @@ def mcp_server_sandbox(
73
73
  cwd: str | Path | None = None,
74
74
  env: dict[str, str] | None = None,
75
75
  sandbox: str | None = None,
76
+ timeout: int | None = None,
76
77
  ) -> MCPServer:
77
78
  """MCP Server (Sandbox).
78
79
 
@@ -87,6 +88,7 @@ def mcp_server_sandbox(
87
88
  "SHELL", "TERM", and "USER" for Posix-based systems).
88
89
  cwd: The working directory to use when spawning the process.
89
90
  sandbox: The sandbox to use when spawning the process.
91
+ timeout: Timeout (in seconds) for command.
90
92
 
91
93
  Returns:
92
94
  McpClient: Client for MCP Server
@@ -94,7 +96,7 @@ def mcp_server_sandbox(
94
96
  verfify_mcp_package()
95
97
  from ._mcp import create_server_sandbox
96
98
 
97
- return create_server_sandbox(command, args, cwd, env, sandbox)
99
+ return create_server_sandbox(command, args, cwd, env, sandbox, timeout)
98
100
 
99
101
 
100
102
  def verfify_mcp_package() -> None:
@@ -68,9 +68,12 @@ class ToolCallError:
68
68
  "permission",
69
69
  "file_not_found",
70
70
  "is_a_directory",
71
- "output_limit",
71
+ "limit",
72
72
  "approval",
73
73
  "unknown",
74
+ # Retained for backward compatibility when loading logs created with an older
75
+ # version of inspect.
76
+ "output_limit",
74
77
  ]
75
78
  """Error type."""
76
79