inspect-ai 0.3.82__py3-none-any.whl → 0.3.84__py3-none-any.whl

This diff compares the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (180)
  1. inspect_ai/__init__.py +2 -1
  2. inspect_ai/_display/textual/app.py +14 -3
  3. inspect_ai/_display/textual/display.py +4 -0
  4. inspect_ai/_display/textual/widgets/samples.py +9 -3
  5. inspect_ai/_display/textual/widgets/task_detail.py +3 -4
  6. inspect_ai/_display/textual/widgets/tasks.py +17 -1
  7. inspect_ai/_display/textual/widgets/vscode.py +48 -0
  8. inspect_ai/_eval/eval.py +36 -24
  9. inspect_ai/_eval/evalset.py +17 -18
  10. inspect_ai/_eval/loader.py +34 -11
  11. inspect_ai/_eval/run.py +8 -13
  12. inspect_ai/_eval/score.py +13 -3
  13. inspect_ai/_eval/task/generate.py +8 -9
  14. inspect_ai/_eval/task/log.py +2 -0
  15. inspect_ai/_eval/task/task.py +23 -9
  16. inspect_ai/_util/file.py +13 -0
  17. inspect_ai/_util/json.py +2 -1
  18. inspect_ai/_util/registry.py +1 -0
  19. inspect_ai/_util/vscode.py +37 -0
  20. inspect_ai/_view/www/App.css +6 -0
  21. inspect_ai/_view/www/dist/assets/index.css +304 -128
  22. inspect_ai/_view/www/dist/assets/index.js +47495 -27519
  23. inspect_ai/_view/www/log-schema.json +124 -31
  24. inspect_ai/_view/www/package.json +3 -0
  25. inspect_ai/_view/www/src/App.tsx +12 -0
  26. inspect_ai/_view/www/src/appearance/icons.ts +1 -0
  27. inspect_ai/_view/www/src/components/Card.tsx +6 -4
  28. inspect_ai/_view/www/src/components/LinkButton.module.css +16 -0
  29. inspect_ai/_view/www/src/components/LinkButton.tsx +33 -0
  30. inspect_ai/_view/www/src/components/LiveVirtualList.tsx +1 -1
  31. inspect_ai/_view/www/src/components/MarkdownDiv.tsx +113 -23
  32. inspect_ai/_view/www/src/components/Modal.module.css +38 -0
  33. inspect_ai/_view/www/src/components/Modal.tsx +77 -0
  34. inspect_ai/_view/www/src/plan/DetailStep.module.css +4 -0
  35. inspect_ai/_view/www/src/plan/DetailStep.tsx +6 -3
  36. inspect_ai/_view/www/src/plan/SolverDetailView.module.css +2 -1
  37. inspect_ai/_view/www/src/samples/InlineSampleDisplay.tsx +7 -0
  38. inspect_ai/_view/www/src/samples/SampleDialog.tsx +7 -0
  39. inspect_ai/_view/www/src/samples/SampleDisplay.tsx +11 -34
  40. inspect_ai/_view/www/src/samples/SampleSummaryView.module.css +6 -0
  41. inspect_ai/_view/www/src/samples/SampleSummaryView.tsx +2 -2
  42. inspect_ai/_view/www/src/samples/SamplesTools.tsx +12 -0
  43. inspect_ai/_view/www/src/samples/chat/MessageContent.tsx +2 -0
  44. inspect_ai/_view/www/src/samples/chat/MessageContents.tsx +2 -0
  45. inspect_ai/_view/www/src/samples/chat/messages.ts +3 -1
  46. inspect_ai/_view/www/src/samples/chat/tools/ToolCallView.tsx +1 -0
  47. inspect_ai/_view/www/src/samples/descriptor/samplesDescriptor.tsx +9 -3
  48. inspect_ai/_view/www/src/samples/descriptor/score/BooleanScoreDescriptor.module.css +3 -3
  49. inspect_ai/_view/www/src/samples/descriptor/score/BooleanScoreDescriptor.tsx +1 -1
  50. inspect_ai/_view/www/src/samples/descriptor/score/ObjectScoreDescriptor.module.css +4 -4
  51. inspect_ai/_view/www/src/samples/descriptor/score/ObjectScoreDescriptor.tsx +10 -11
  52. inspect_ai/_view/www/src/samples/list/SampleFooter.module.css +2 -1
  53. inspect_ai/_view/www/src/samples/list/SampleFooter.tsx +7 -1
  54. inspect_ai/_view/www/src/samples/list/SampleList.tsx +25 -8
  55. inspect_ai/_view/www/src/samples/list/SampleRow.tsx +1 -1
  56. inspect_ai/_view/www/src/samples/scores/SampleScores.tsx +11 -22
  57. inspect_ai/_view/www/src/samples/scores/SampleScoresGrid.module.css +38 -0
  58. inspect_ai/_view/www/src/samples/scores/SampleScoresGrid.tsx +118 -0
  59. inspect_ai/_view/www/src/samples/scores/{SampleScoreView.module.css → SampleScoresView.module.css} +10 -1
  60. inspect_ai/_view/www/src/samples/scores/SampleScoresView.tsx +78 -0
  61. inspect_ai/_view/www/src/samples/transcript/SampleLimitEventView.tsx +3 -3
  62. inspect_ai/_view/www/src/samples/transcript/ToolEventView.tsx +25 -4
  63. inspect_ai/_view/www/src/samples/transcript/event/EventPanel.tsx +29 -2
  64. inspect_ai/_view/www/src/samples/transcript/state/StateEventRenderers.tsx +0 -1
  65. inspect_ai/_view/www/src/state/hooks.ts +5 -3
  66. inspect_ai/_view/www/src/state/logPolling.ts +5 -1
  67. inspect_ai/_view/www/src/state/logSlice.ts +10 -0
  68. inspect_ai/_view/www/src/state/samplePolling.ts +4 -1
  69. inspect_ai/_view/www/src/state/sampleSlice.ts +13 -0
  70. inspect_ai/_view/www/src/types/log.d.ts +34 -26
  71. inspect_ai/_view/www/src/types/markdown-it-katex.d.ts +21 -0
  72. inspect_ai/_view/www/src/utils/json-worker.ts +79 -12
  73. inspect_ai/_view/www/src/workspace/WorkSpace.tsx +18 -16
  74. inspect_ai/_view/www/src/workspace/navbar/ResultsPanel.module.css +16 -0
  75. inspect_ai/_view/www/src/workspace/navbar/ResultsPanel.tsx +68 -71
  76. inspect_ai/_view/www/src/workspace/navbar/ScoreGrid.module.css +35 -0
  77. inspect_ai/_view/www/src/workspace/navbar/ScoreGrid.tsx +117 -0
  78. inspect_ai/_view/www/src/workspace/navbar/SecondaryBar.tsx +1 -1
  79. inspect_ai/_view/www/src/workspace/sidebar/Sidebar.module.css +3 -2
  80. inspect_ai/_view/www/src/workspace/tabs/SamplesTab.tsx +18 -0
  81. inspect_ai/_view/www/yarn.lock +94 -1
  82. inspect_ai/agent/__init__.py +36 -0
  83. inspect_ai/agent/_agent.py +268 -0
  84. inspect_ai/agent/_as_solver.py +72 -0
  85. inspect_ai/agent/_as_tool.py +122 -0
  86. inspect_ai/{solver → agent}/_bridge/bridge.py +23 -37
  87. inspect_ai/{solver → agent}/_bridge/patch.py +9 -8
  88. inspect_ai/agent/_filter.py +46 -0
  89. inspect_ai/agent/_handoff.py +93 -0
  90. inspect_ai/{solver/_human_agent → agent/_human}/agent.py +11 -12
  91. inspect_ai/{solver/_human_agent → agent/_human}/commands/__init__.py +2 -3
  92. inspect_ai/{solver/_human_agent → agent/_human}/commands/clock.py +3 -1
  93. inspect_ai/{solver/_human_agent → agent/_human}/commands/score.py +5 -5
  94. inspect_ai/{solver/_human_agent → agent/_human}/install.py +6 -3
  95. inspect_ai/{solver/_human_agent → agent/_human}/service.py +7 -3
  96. inspect_ai/{solver/_human_agent → agent/_human}/state.py +5 -5
  97. inspect_ai/agent/_react.py +241 -0
  98. inspect_ai/agent/_run.py +36 -0
  99. inspect_ai/agent/_types.py +81 -0
  100. inspect_ai/log/_log.py +11 -2
  101. inspect_ai/log/_transcript.py +13 -9
  102. inspect_ai/model/__init__.py +7 -1
  103. inspect_ai/model/_call_tools.py +256 -52
  104. inspect_ai/model/_chat_message.py +7 -4
  105. inspect_ai/model/_conversation.py +13 -62
  106. inspect_ai/model/_display.py +85 -0
  107. inspect_ai/model/_model.py +113 -14
  108. inspect_ai/model/_model_output.py +14 -9
  109. inspect_ai/model/_openai.py +16 -4
  110. inspect_ai/model/_openai_computer_use.py +162 -0
  111. inspect_ai/model/_openai_responses.py +319 -165
  112. inspect_ai/model/_providers/anthropic.py +20 -21
  113. inspect_ai/model/_providers/azureai.py +24 -13
  114. inspect_ai/model/_providers/bedrock.py +1 -7
  115. inspect_ai/model/_providers/cloudflare.py +3 -3
  116. inspect_ai/model/_providers/goodfire.py +2 -6
  117. inspect_ai/model/_providers/google.py +11 -10
  118. inspect_ai/model/_providers/groq.py +6 -3
  119. inspect_ai/model/_providers/hf.py +7 -3
  120. inspect_ai/model/_providers/mistral.py +7 -10
  121. inspect_ai/model/_providers/openai.py +47 -17
  122. inspect_ai/model/_providers/openai_o1.py +11 -4
  123. inspect_ai/model/_providers/openai_responses.py +12 -14
  124. inspect_ai/model/_providers/providers.py +2 -2
  125. inspect_ai/model/_providers/together.py +12 -2
  126. inspect_ai/model/_providers/util/chatapi.py +7 -2
  127. inspect_ai/model/_providers/util/hf_handler.py +4 -2
  128. inspect_ai/model/_providers/util/llama31.py +4 -2
  129. inspect_ai/model/_providers/vertex.py +11 -9
  130. inspect_ai/model/_providers/vllm.py +4 -4
  131. inspect_ai/scorer/__init__.py +2 -0
  132. inspect_ai/scorer/_metrics/__init__.py +2 -0
  133. inspect_ai/scorer/_metrics/grouped.py +84 -0
  134. inspect_ai/scorer/_score.py +26 -6
  135. inspect_ai/solver/__init__.py +2 -2
  136. inspect_ai/solver/_basic_agent.py +22 -9
  137. inspect_ai/solver/_bridge.py +31 -0
  138. inspect_ai/solver/_chain.py +20 -12
  139. inspect_ai/solver/_fork.py +5 -1
  140. inspect_ai/solver/_human_agent.py +52 -0
  141. inspect_ai/solver/_prompt.py +3 -1
  142. inspect_ai/solver/_run.py +59 -0
  143. inspect_ai/solver/_solver.py +14 -4
  144. inspect_ai/solver/_task_state.py +5 -3
  145. inspect_ai/tool/_tool_call.py +15 -8
  146. inspect_ai/tool/_tool_def.py +17 -12
  147. inspect_ai/tool/_tool_support_helpers.py +2 -2
  148. inspect_ai/tool/_tool_with.py +14 -11
  149. inspect_ai/tool/_tools/_bash_session.py +11 -2
  150. inspect_ai/tool/_tools/_computer/_common.py +18 -2
  151. inspect_ai/tool/_tools/_computer/_computer.py +18 -2
  152. inspect_ai/tool/_tools/_computer/_resources/tool/_constants.py +2 -0
  153. inspect_ai/tool/_tools/_computer/_resources/tool/_x11_client.py +17 -0
  154. inspect_ai/tool/_tools/_think.py +1 -1
  155. inspect_ai/tool/_tools/_web_browser/_web_browser.py +100 -61
  156. inspect_ai/util/__init__.py +2 -0
  157. inspect_ai/util/_anyio.py +27 -0
  158. inspect_ai/util/_sandbox/__init__.py +2 -1
  159. inspect_ai/util/_sandbox/context.py +32 -7
  160. inspect_ai/util/_sandbox/docker/cleanup.py +4 -0
  161. inspect_ai/util/_sandbox/docker/compose.py +2 -2
  162. inspect_ai/util/_sandbox/docker/docker.py +12 -1
  163. inspect_ai/util/_store_model.py +30 -7
  164. inspect_ai/util/_subprocess.py +13 -3
  165. {inspect_ai-0.3.82.dist-info → inspect_ai-0.3.84.dist-info}/METADATA +1 -1
  166. {inspect_ai-0.3.82.dist-info → inspect_ai-0.3.84.dist-info}/RECORD +179 -153
  167. inspect_ai/_view/www/src/samples/scores/SampleScoreView.tsx +0 -167
  168. /inspect_ai/{solver → agent}/_bridge/__init__.py +0 -0
  169. /inspect_ai/{solver/_human_agent → agent/_human}/__init__.py +0 -0
  170. /inspect_ai/{solver/_human_agent → agent/_human}/commands/command.py +0 -0
  171. /inspect_ai/{solver/_human_agent → agent/_human}/commands/instructions.py +0 -0
  172. /inspect_ai/{solver/_human_agent → agent/_human}/commands/note.py +0 -0
  173. /inspect_ai/{solver/_human_agent → agent/_human}/commands/status.py +0 -0
  174. /inspect_ai/{solver/_human_agent → agent/_human}/commands/submit.py +0 -0
  175. /inspect_ai/{solver/_human_agent → agent/_human}/panel.py +0 -0
  176. /inspect_ai/{solver/_human_agent → agent/_human}/view.py +0 -0
  177. {inspect_ai-0.3.82.dist-info → inspect_ai-0.3.84.dist-info}/WHEEL +0 -0
  178. {inspect_ai-0.3.82.dist-info → inspect_ai-0.3.84.dist-info}/entry_points.txt +0 -0
  179. {inspect_ai-0.3.82.dist-info → inspect_ai-0.3.84.dist-info}/licenses/LICENSE +0 -0
  180. {inspect_ai-0.3.82.dist-info → inspect_ai-0.3.84.dist-info}/top_level.txt +0 -0
inspect_ai/model/_providers/openai_responses.py

@@ -1,11 +1,7 @@
 from logging import getLogger
 from typing import Any

-from openai import (
-    AsyncAzureOpenAI,
-    AsyncOpenAI,
-    BadRequestError,
-)
+from openai import AsyncAzureOpenAI, AsyncOpenAI, BadRequestError
 from openai._types import NOT_GIVEN
 from openai.types.responses import Response, ResponseFormatTextJSONSchemaConfigParam

@@ -15,12 +11,10 @@ from inspect_ai.tool import ToolChoice, ToolInfo
 from .._chat_message import ChatMessage
 from .._generate_config import GenerateConfig
 from .._model_call import ModelCall
-from .._model_output import (
-    ModelOutput,
-    ModelUsage,
-)
+from .._model_output import ModelOutput, ModelUsage
 from .._openai import (
     OpenAIResponseError,
+    is_computer_use_preview,
     is_gpt,
     is_o1_mini,
     is_o1_preview,
@@ -65,12 +59,14 @@ async def generate_responses(
     )

     # prepare request (we do this so we can log the ModelCall)
+    tool_params = openai_responses_tools(tools, config) if len(tools) > 0 else NOT_GIVEN
     request = dict(
         input=await openai_responses_inputs(input, model_name),
-        tools=openai_responses_tools(tools) if len(tools) > 0 else NOT_GIVEN,
-        tool_choice=openai_responses_tool_choice(tool_choice)
-        if len(tools) > 0
+        tools=tool_params,
+        tool_choice=openai_responses_tool_choice(tool_choice, tool_params)
+        if isinstance(tool_params, list) and tool_choice != "auto"
         else NOT_GIVEN,
+        truncation="auto" if is_computer_use_preview(model_name) else NOT_GIVEN,
         extra_headers={HttpxHooks.REQUEST_ID_HEADER: request_id},
         **completion_params_responses(model_name, config, len(tools) > 0),
     )
@@ -89,7 +85,7 @@ async def generate_responses(
     response = model_response.model_dump()

     # parse out choices
-    choices = openai_responses_chat_choices(model_response, tools)
+    choices = openai_responses_chat_choices(model_name, model_response, tools)

     # return output and call
     return ModelOutput(
@@ -124,7 +120,9 @@ def completion_params_responses(
             f"OpenAI Responses API does not support the '{param}' parameter.",
         )

-    params: dict[str, Any] = dict(model=model_name, store=False)
+    params: dict[str, Any] = dict(
+        model=model_name, store=is_computer_use_preview(model_name)
+    )
     if config.max_tokens is not None:
         params["max_output_tokens"] = config.max_tokens
     if config.frequency_penalty is not None:
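For orientation, a minimal sketch of exercising the new computer-use path (hedged: the computer tool wiring is omitted, and the model name assumes access to OpenAI's computer-use-preview; per the hunks above, requests for that model now carry truncation="auto" and store=True automatically):

import asyncio

from inspect_ai.model import get_model

async def main() -> None:
    # for this model the Responses API request now includes
    # truncation="auto" and store=True (see is_computer_use_preview above)
    model = get_model("openai/computer-use-preview")
    output = await model.generate("Describe what you see on screen.")
    print(output.completion)

asyncio.run(main())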
inspect_ai/model/_providers/providers.py

@@ -48,7 +48,7 @@ def openai() -> type[ModelAPI]:
 def anthropic() -> type[ModelAPI]:
     FEATURE = "Anthropic API"
     PACKAGE = "anthropic"
-    MIN_VERSION = "0.47.1"
+    MIN_VERSION = "0.49.0"

     # verify we have the package
     try:
@@ -278,7 +278,7 @@ def goodfire() -> type[ModelAPI]:
 def validate_openai_client(feature: str) -> None:
     FEATURE = feature
     PACKAGE = "openai"
-    MIN_VERSION = "1.68.0"
+    MIN_VERSION = "1.69.0"

     # verify we have the package
     try:
inspect_ai/model/_providers/together.py

@@ -68,7 +68,9 @@ def chat_choices_from_response_together(
         logprobs_models.append(Logprobs(content=logprobs_sequence))
     return [
         ChatCompletionChoice(
-            message=chat_message_assistant_from_openai(choice.message, tools),
+            message=chat_message_assistant_from_openai(
+                response.model, choice.message, tools
+            ),
             stop_reason=as_stop_reason(choice.finish_reason),
             logprobs=logprobs,
         )
@@ -116,6 +118,14 @@ class TogetherAIAPI(OpenAIAPI):
         else:
             return ex

+    @override
+    def set_logprobs_params(
+        self, params: dict[str, Any], config: GenerateConfig
+    ) -> dict[str, Any]:
+        if config.logprobs is True:
+            params["logprobs"] = 1
+        return params
+
     # Together has a slightly different logprobs structure to OpenAI, so we need to remap it.
     def _chat_choices_from_response(
         self, response: ChatCompletion, tools: list[ToolInfo]
@@ -228,7 +238,7 @@ class TogetherRESTAPI(ModelAPI):
         return DEFAULT_MAX_TOKENS

     def chat_api_handler(self) -> ChatAPIHandler:
-        return ChatAPIHandler()
+        return ChatAPIHandler(self.model_name)


 def together_choices(
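A hedged usage sketch of the logprobs remapping above: when logprobs=True is requested, set_logprobs_params now sends Together's integer form (logprobs=1). The model name is illustrative:

import asyncio

from inspect_ai.model import GenerateConfig, get_model

async def main() -> None:
    model = get_model("together/meta-llama/Llama-3.1-8B-Instruct")
    # config.logprobs is True -> the provider emits logprobs=1
    output = await model.generate("Hello", config=GenerateConfig(logprobs=True))
    print(output.choices[0].logprobs)

asyncio.run(main())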
inspect_ai/model/_providers/util/chatapi.py

@@ -23,6 +23,9 @@ ChatAPIMessage = dict[Literal["role", "content"], str]


 class ChatAPIHandler:
+    def __init__(self, model: str) -> None:
+        self.model = model
+
     def input_with_tools(
         self, input: list[ChatMessage], tools: list[ToolInfo]
     ) -> list[ChatMessage]:
@@ -31,7 +34,9 @@ class ChatAPIHandler:
     def parse_assistant_response(
         self, response: str, tools: list[ToolInfo]
     ) -> ChatMessageAssistant:
-        return ChatMessageAssistant(content=response)
+        return ChatMessageAssistant(
+            content=response, model=self.model, source="generate"
+        )

     def assistant_message(self, message: ChatMessageAssistant) -> ChatAPIMessage:
         return {"role": "assistant", "content": message.text}
@@ -48,7 +53,7 @@ class ChatAPIHandler:
 def chat_api_input(
     input: list[ChatMessage],
     tools: list[ToolInfo],
-    handler: ChatAPIHandler = ChatAPIHandler(),
+    handler: ChatAPIHandler,
 ) -> list[ChatAPIMessage]:
     # add tools to input
     if len(tools) > 0:
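Since chat_api_input no longer defaults its handler argument, callers must construct a ChatAPIHandler with the model name, which the handler now threads into parsed assistant messages. A sketch against these internal modules (import paths assumed from the hunks above and from vllm.py below, which imports both names from .util):

from inspect_ai.model import ChatMessageUser
from inspect_ai.model._providers.util import ChatAPIHandler, chat_api_input

# the handler now carries the model name so parse_assistant_response()
# can stamp model= onto the ChatMessageAssistant it returns
handler = ChatAPIHandler("my-model")
messages = chat_api_input([ChatMessageUser(content="Hello")], [], handler)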
inspect_ai/model/_providers/util/hf_handler.py

@@ -50,13 +50,16 @@ class HFHandler(ChatAPIHandler):
             return ChatMessageAssistant(
                 content=content,
                 tool_calls=tool_calls,
+                model=self.model_name,
                 source="generate",
             )

         # otherwise this is just an ordinary assistant message
         else:
             return ChatMessageAssistant(
-                content=filter_assistant_header(response), source="generate"
+                content=filter_assistant_header(response),
+                model=self.model_name,
+                source="generate",
             )


@@ -106,7 +109,6 @@ def parse_tool_call_content(content: str, tools: list[ToolInfo]) -> ToolCall:
         id="unknown",
         function="unknown",
         arguments={},
-        type="function",
         parse_error=parse_error,
     )
inspect_ai/model/_providers/util/llama31.py

@@ -106,13 +106,16 @@ class Llama31Handler(ChatAPIHandler):
             return ChatMessageAssistant(
                 content=filter_assistant_header(content),
                 tool_calls=tool_calls,
+                model=self.model,
                 source="generate",
             )

         # otherwise this is just an ordinary assistant message
         else:
             return ChatMessageAssistant(
-                content=filter_assistant_header(response), source="generate"
+                content=filter_assistant_header(response),
+                model=self.model,
+                source="generate",
             )

     @override
@@ -184,7 +187,6 @@ def parse_tool_call_content(content: str, tools: list[ToolInfo]) -> ToolCall:
         id="unknown",
         function="unknown",
         arguments={},
-        type="function",
         parse_error=parse_error,
     )
inspect_ai/model/_providers/vertex.py

@@ -116,11 +116,6 @@ class VertexAPI(ModelAPI):

         self.model = GenerativeModel(model_name)

-    @override
-    async def close(self) -> None:
-        # GenerativeModel uses a cached/shared client so there is no 'close'
-        pass
-
     async def generate(
         self,
         input: list[ChatMessage],
@@ -155,7 +150,9 @@ class VertexAPI(ModelAPI):
         # capture output
         output = ModelOutput(
             model=self.model_name,
-            choices=completion_choices_from_candidates(response.candidates),
+            choices=completion_choices_from_candidates(
+                self.model_name, response.candidates
+            ),
             usage=ModelUsage(
                 input_tokens=response.usage_metadata.prompt_token_count,
                 output_tokens=response.usage_metadata.candidates_token_count,
@@ -377,7 +374,9 @@ def chat_tools(tools: list[ToolInfo]) -> list[Tool]:
     return [Tool(function_declarations=declarations)]


-def completion_choice_from_candidate(candidate: Candidate) -> ChatCompletionChoice:
+def completion_choice_from_candidate(
+    model: str, candidate: Candidate
+) -> ChatCompletionChoice:
     # check for completion text
     content = " ".join(
         [
@@ -394,7 +393,6 @@ def completion_choice_from_candidate(candidate: Candidate) -> ChatCompletionChoice:
             function_call = MessageToDict(getattr(part.function_call, "_pb"))
             tool_calls.append(
                 ToolCall(
-                    type="function",
                     id=function_call["name"],
                     function=function_call["name"],
                     arguments=function_call["args"],
@@ -408,6 +406,7 @@ def completion_choice_from_candidate(candidate: Candidate) -> ChatCompletionChoice:
         message=ChatMessageAssistant(
             content=content,
             tool_calls=tool_calls if len(tool_calls) > 0 else None,
+            model=model,
             source="generate",
         ),
         stop_reason=stop_reason,
@@ -435,11 +434,14 @@ def completion_choice_from_candidate(candidate: Candidate) -> ChatCompletionChoice:


 def completion_choices_from_candidates(
+    model: str,
     candidates: list[Candidate],
 ) -> list[ChatCompletionChoice]:
     candidates = copy(candidates)
     candidates.sort(key=lambda c: c.index)
-    return [completion_choice_from_candidate(candidate) for candidate in candidates]
+    return [
+        completion_choice_from_candidate(model, candidate) for candidate in candidates
+    ]


 def candidate_stop_reason(finish_reason: FinishReason) -> StopReason:
inspect_ai/model/_providers/vllm.py

@@ -28,7 +28,7 @@ from .._model_output import (
     StopReason,
     TopLogprob,
 )
-from .util import chat_api_input
+from .util import ChatAPIHandler, chat_api_input

 DEFAULT_START_TOKEN = "<|im_start|>"
 DEFAULT_END_TOKEN = "<|im_end|>"
@@ -137,7 +137,7 @@ class VLLMAPI(ModelAPI):
         self.tokenizer = self.model.get_tokenizer()

     @override
-    async def close(self) -> None:
+    def close(self) -> None:
         self.tokenizer = None
         self.model = None
         gc.collect()
@@ -148,7 +148,7 @@ class VLLMAPI(ModelAPI):
         # handle system message and consecutive user messages
         messages = simple_input_messages(messages)
         # convert to chat template input format
-        chat_messages = chat_api_input(messages, tools)
+        chat_messages = chat_api_input(messages, tools, ChatAPIHandler(self.model_name))
         # apply chat template
         chat = self.tokenizer.apply_chat_template(
             chat_messages,
@@ -253,7 +253,7 @@ class VLLMAPI(ModelAPI):
         choices = [
             ChatCompletionChoice(
                 message=ChatMessageAssistant(
-                    content=response.output, source="generate"
+                    content=response.output, model=self.model_name, source="generate"
                 ),
                 stop_reason=response.stop_reason,
                 logprobs=response.logprobs,
inspect_ai/scorer/__init__.py

@@ -19,6 +19,7 @@ from ._metric import (
     value_to_float,
 )
 from ._metrics.accuracy import accuracy
+from ._metrics.grouped import grouped
 from ._metrics.mean import mean
 from ._metrics.std import bootstrap_stderr, std, stderr, var
 from ._model import model_graded_fact, model_graded_qa
@@ -58,6 +59,7 @@ __all__ = [
     "std",
     "stderr",
     "mean",
+    "grouped",
     "var",
     "Metric",
     "MetricProtocol",
inspect_ai/scorer/_metrics/__init__.py

@@ -1,10 +1,12 @@
 from .accuracy import accuracy
+from .grouped import grouped
 from .mean import mean
 from .std import bootstrap_stderr, std, stderr, var

 __all__ = [
     "accuracy",
     "mean",
+    "grouped",
     "bootstrap_stderr",
     "std",
     "stderr",
inspect_ai/scorer/_metrics/grouped.py (new file)

@@ -0,0 +1,84 @@
+from typing import Literal, cast
+
+import numpy as np
+
+from inspect_ai.scorer._metric import (
+    Metric,
+    MetricProtocol,
+    SampleScore,
+    Value,
+    ValueToFloat,
+    metric,
+    value_to_float,
+)
+
+
+@metric
+def grouped(
+    metric: Metric,
+    group_key: str,
+    *,
+    all: Literal["samples", "groups"] | Literal[False] = "samples",
+    all_label: str = "all",
+    value_to_float: ValueToFloat = value_to_float(),
+) -> Metric:
+    """
+    Creates a grouped metric that applies the given metric to subgroups of samples.
+
+    Args:
+        metric: The metric to apply to each group of samples.
+        group_key: The metadata key used to group samples. Each sample must have this key in its metadata.
+        all: How to compute the "all" aggregate score:
+            - "samples": Apply the metric to all samples regardless of groups
+            - "groups": Calculate the mean of all group scores
+            - False: Don't calculate an aggregate score
+        all_label: The label for the "all" key in the returned dictionary.
+        value_to_float: Function to convert metric values to floats, used when all="groups".
+
+    Returns:
+        A new metric function that returns a dictionary mapping group names to their scores,
+        with an optional "all" key for the aggregate score.
+    """
+
+    def grouped_metric(scores: list[SampleScore]) -> Value:
+        # Satisfy the type checker that the metric is a MetricProtocol
+        metric_protocol = cast(MetricProtocol, metric)
+
+        # Slice the scores into groups
+        scores_dict: dict[str, list[SampleScore]] = {}
+        for sample_score in scores:
+            if (
+                sample_score.sample_metadata is None
+                or group_key not in sample_score.sample_metadata
+            ):
+                raise ValueError(
+                    f"Sample {sample_score.sample_id} has no {group_key} metadata. To compute a grouped metric each sample metadata must have a value for '{group_key}'"
+                )
+            group_name = str(sample_score.sample_metadata.get(group_key))
+            if group_name not in scores_dict:
+                scores_dict[group_name] = []
+            scores_dict[group_name].append(sample_score)
+
+        # Compute the per group metric
+        grouped_scores = {
+            group_name: metric_protocol(values)
+            for group_name, values in scores_dict.items()
+        }
+
+        if not all:
+            return cast(Value, grouped_scores)
+        else:
+            # Compute the all metric
+            all_group_metric = None
+            if all == "samples":
+                # samples means apply the metric to all samples
+                all_group_metric = metric_protocol(scores)
+            elif all == "groups":
+                # groups means the overall score is the mean of all the group scores
+                all_group_metric = np.mean(
+                    [value_to_float(val) for val in grouped_scores.values()]
+                ).item()
+
+            return cast(Value, {**grouped_scores, all_label: all_group_metric})
+
+    return grouped_metric
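A usage sketch for the new grouped metric (the task, dataset, and "category" metadata key are illustrative; metrics= on Task overrides the scorer's default metrics):

from inspect_ai import Task, task
from inspect_ai.dataset import Sample
from inspect_ai.scorer import accuracy, exact, grouped

@task
def math_groups() -> Task:
    return Task(
        dataset=[
            # every sample must carry the metadata key that grouped() splits on
            Sample(input="1+1", target="2", metadata={"category": "arithmetic"}),
            Sample(input="Solve x+1=3 for x", target="2", metadata={"category": "algebra"}),
        ],
        scorer=exact(),
        # reports accuracy per category plus an "all" entry over all samples
        metrics=[grouped(accuracy(), "category")],
    )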
inspect_ai/scorer/_score.py

@@ -1,30 +1,50 @@
 from contextvars import ContextVar
+from copy import copy

-from inspect_ai.solver._task_state import TaskState
+from inspect_ai.model._conversation import ModelConversation
+from inspect_ai.solver._task_state import TaskState, sample_state

 from ._metric import Score
 from ._scorer import Scorer
 from ._target import Target


-async def score(state: TaskState) -> list[Score]:
-    """Score a TaskState.
+async def score(conversation: ModelConversation) -> list[Score]:
+    """Score a model conversation.

-    Score a task state from within a solver.
+    Score a model conversation (you may pass `TaskState` or `AgentState`
+    as the value for `conversation`)

     Args:
-        state (TaskState): `TaskState` to submit for scoring
+        conversation: Conversation to submit for scoring.
+            Note that both `TaskState` and `AgentState` can be passed
+            as the `conversation` parameter.

     Returns:
         List of scores (one for each task scorer)

     Raises:
-        RuntimerError: If called from outside a task or within
+        RuntimeError: If called from outside a task or within
         a task that does not have a scorer.

     """
     from inspect_ai.log._transcript import ScoreEvent, transcript

+    # get TaskState (if the `conversation` is a `TaskState` use it directly,
+    # otherwise synthesize one)
+    if isinstance(conversation, TaskState):
+        state = conversation
+    else:
+        current_state = sample_state()
+        if current_state is None:
+            raise RuntimeError(
+                "The score() function can only be called while executing a task"
+            )
+        state = copy(current_state)
+        state.messages = conversation.messages
+        state.output = conversation.output
+
+    # get current scorers and target
     scorers = _scorers.get(None)
     target = _target.get(None)
     if scorers is None or target is None:
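A sketch of calling the updated score() from a solver; because a TaskState is passed directly, no synthetic state is needed (the solver itself is illustrative):

from inspect_ai.scorer import score
from inspect_ai.solver import Generate, Solver, TaskState, solver

@solver
def generate_and_score() -> Solver:
    async def solve(state: TaskState, generate: Generate) -> TaskState:
        state = await generate(state)
        # score() now accepts any ModelConversation: a TaskState here, or
        # an AgentState (for which it synthesizes a TaskState as above)
        scores = await score(state)
        state.metadata["interim_scores"] = [s.value for s in scores]
        return state

    return solve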
inspect_ai/solver/__init__.py

@@ -1,11 +1,11 @@
 from inspect_ai._util.deprecation import relocated_module_attribute

 from ._basic_agent import basic_agent
-from ._bridge.bridge import bridge
+from ._bridge import bridge
 from ._chain import chain
 from ._critique import self_critique
 from ._fork import fork
-from ._human_agent.agent import human_agent
+from ._human_agent import human_agent
 from ._limit import SampleLimitExceededError
 from ._multiple_choice import MultipleChoiceTemplate, multiple_choice
 from ._plan import Plan, plan
inspect_ai/solver/_basic_agent.py

@@ -5,8 +5,8 @@ from typing_extensions import TypedDict, Unpack

 from inspect_ai._util._async import is_callable_coroutine
 from inspect_ai.model._cache import CachePolicy
-from inspect_ai.model._call_tools import call_tools
-from inspect_ai.model._chat_message import ChatMessageTool, ChatMessageUser
+from inspect_ai.model._call_tools import execute_tools
+from inspect_ai.model._chat_message import ChatMessage, ChatMessageTool, ChatMessageUser
 from inspect_ai.model._model import get_model
 from inspect_ai.scorer._metric import Score, ValueToFloat, value_to_float
 from inspect_ai.scorer._score import score
@@ -65,6 +65,7 @@ def basic_agent(
     continue_message: str = DEFAULT_CONTINUE_MESSAGE,
     submit_name: str = DEFAULT_SUBMIT_NAME,
     submit_description: str = DEFAULT_SUBMIT_DESCRIPTION,
+    submit_append: bool = False,
     **kwargs: Unpack[BasicAgentDeprecatedArgs],
 ) -> Solver:
     """Basic ReAct agent.
@@ -102,6 +103,9 @@ def basic_agent(
             (defaults to 'submit')
         submit_description: Description of submit tool (defaults to
             'Submit an answer for evaluation')
+        submit_append: Append the submit tool output to the model completion
+            text (defaults to `False`, which means the submission overwrites
+            the model completion).
         **kwargs: Deprecated arguments for backward compatibility.

     Returns:
@@ -149,9 +153,14 @@ def basic_agent(
         return solve

     # helper to extract a submitted answer
-    def submission(tool_results: list[ChatMessageTool]) -> str | None:
+    def submission(tool_results: list[ChatMessage]) -> str | None:
         return next(
-            (result.text for result in tool_results if result.function == submit_name),
+            (
+                result.text
+                for result in tool_results
+                if isinstance(result, ChatMessageTool)
+                and result.function == submit_name
+            ),
             None,
         )
@@ -189,9 +198,9 @@ def basic_agent(

            # resolve tools calls (if any)
            if state.output.message.tool_calls:
-                # call tool functions
-                tool_results = await call_tools(
-                    state.output.message,
+                # execute tool functions
+                tool_results, _ = await execute_tools(
+                    [state.output.message],
                     state.tools,
                     max_output=max_tool_output,
                 )
@@ -200,8 +209,12 @@ def basic_agent(
                # was an answer submitted?
                answer = submission(tool_results)
                if answer:
-                    # set the output to the answer for scoring
-                    state.output.completion = answer
+                    if submit_append:
+                        state.output.completion = (
+                            f"{state.output.completion}\n\n{answer}".strip()
+                        )
+                    else:
+                        state.output.completion = answer

                # exit if we are at max_attempts
                attempts += 1
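A sketch showing the new submit_append flag (the tool and system prompt are illustrative); with it set, the submitted answer is appended to the model's completion rather than replacing it:

from inspect_ai.solver import basic_agent, system_message
from inspect_ai.tool import bash

agent = basic_agent(
    init=system_message("Solve the task, then call submit() with your answer."),
    tools=[bash()],
    # keep the model's final reasoning and append the submission after it
    submit_append=True,
)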
inspect_ai/solver/_bridge.py (new file)

@@ -0,0 +1,31 @@
+from logging import getLogger
+from typing import Any, Awaitable, Callable
+
+from inspect_ai._util.logger import warn_once
+from inspect_ai.agent._as_solver import as_solver
+
+from ._solver import Solver, solver
+
+logger = getLogger(__name__)
+
+
+@solver
+def bridge(agent: Callable[[dict[str, Any]], Awaitable[dict[str, Any]]]) -> Solver:
+    """Bridge an external agent into an Inspect Solver.
+
+    See documentation at <https://inspect.ai-safety-institute.org.uk/agent-bridge.html>
+
+    Args:
+        agent: Callable which takes a sample `dict` and returns a result `dict`.
+
+    Returns:
+        Standard Inspect solver.
+    """
+    from inspect_ai.agent._bridge.bridge import bridge as agent_bridge
+
+    warn_once(
+        logger,
+        "The bridge solver is deprecated. Please use the bridge agent from the agents module instead.",
+    )
+
+    return as_solver(agent_bridge(agent))
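The deprecation path, sketched (assuming bridge and as_solver are exported from inspect_ai.agent, as the new module layout suggests; the agent body is a placeholder):

from typing import Any

from inspect_ai.agent import as_solver, bridge

async def my_agent(sample: dict[str, Any]) -> dict[str, Any]:
    # external agents receive an OpenAI-format sample dict and
    # return a result dict with an "output" string
    return {"output": "hello from the external agent"}

# preferred going forward: the agent-module bridge, adapted back to a
# Solver only where a Solver is actually required
my_solver = as_solver(bridge(my_agent))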
inspect_ai/solver/_chain.py

@@ -1,14 +1,19 @@
-from typing import Sequence, overload
+from typing import Sequence, cast, overload

 from typing_extensions import override

+from inspect_ai.agent._agent import Agent, is_agent
+from inspect_ai.agent._as_solver import as_solver
+
 from ._solver import Generate, Solver, solver
 from ._task_state import TaskState


 @solver
-def chain(*solvers: Solver | list[Solver]) -> Solver:
-    """Compose a solver from multiple other solvers.
+def chain(
+    *solvers: Solver | Agent | list[Solver] | list[Solver | Agent],
+) -> Solver:
+    """Compose a solver from multiple other solvers and/or agents.

     Solvers are executed in turn, and a solver step event
     is added to the transcript for each. If a solver returns
@@ -16,10 +21,10 @@ def chain(*solvers: Solver | list[Solver]) -> Solver:
     early.

     Args:
-        *solvers: One or more solvers or lists of solvers to chain together.
+        *solvers: One or more solvers or agents to chain together.

     Returns:
-        Solver that executes the passed solvers as a chain.
+        Solver that executes the passed solvers and agents as a chain.
     """
     # flatten lists and chains
     all_solvers: list[Solver] = []
@@ -29,17 +34,20 @@ def chain(*solvers: Solver | list[Solver]) -> Solver:
     return Chain(all_solvers)


-def unroll(solver: Solver | list[Solver]) -> list[Solver]:
-    if isinstance(solver, Solver):
-        if isinstance(solver, Chain):
-            return unroll(solver._solvers)
-        else:
-            return [solver]
-    else:
+def unroll(
+    solver: Solver | Agent | list[Solver] | list[Solver | Agent],
+) -> list[Solver]:
+    if isinstance(solver, list):
         unrolled: list[Solver] = []
         for s in solver:
             unrolled.extend(unroll(s))
         return unrolled
+    elif is_agent(solver):
+        return [as_solver(solver)]
+    elif isinstance(solver, Chain):
+        return unroll(solver._solvers)
+    else:
+        return [cast(Solver, solver)]


 class Chain(Sequence[Solver], Solver):
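An illustrative composition using the widened chain() signature (assuming react() from the new agent module is usable with its defaults; see inspect_ai/agent/_react.py in the file list above):

from inspect_ai.agent import react
from inspect_ai.solver import chain, prompt_template

# chain() now flattens solvers and agents together: unroll() wraps
# each Agent with as_solver() so the Chain stays a list of Solvers
plan = chain(
    prompt_template("{prompt}\n\nWork step by step."),
    react(),
)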
inspect_ai/solver/_fork.py

@@ -52,7 +52,7 @@ async def fork(

 async def solver_subtask(state: TaskState, solver: Solver) -> TaskState:
     # get the generate function for the current task
-    generate = _generate.get(None)
+    generate = task_generate()
     if generate is None:
         raise RuntimeError("Called fork() outside of a running task.")

@@ -88,4 +88,8 @@ def set_task_generate(generate: Generate) -> None:
     _generate.set(generate)


+def task_generate() -> Generate | None:
+    return _generate.get(None)
+
+
 _generate: ContextVar[Generate] = ContextVar("_generate")